Arthur-LAGACHERIE committed · Commit f550107 · verified · 1 Parent(s): 0dcd6ba

Update modeling_recllama.py

Files changed (1):
  1. modeling_recllama.py +52 -26
modeling_recllama.py CHANGED
@@ -45,6 +45,7 @@ class RecLlamaConfig(PretrainedConfig):
         coda_layers:int = 2,
         mean_recurrence:int = 12,
         max_backprop_depth:int = 8,
+        max_recurrence:int = 18,
         **kwargs
     ):
         self.vocab_size = vocab_size
@@ -79,6 +80,7 @@ class RecLlamaConfig(PretrainedConfig):
         self.coda_layers = coda_layers
         self.mean_recurrence = mean_recurrence
         self.max_backprop_depth = max_backprop_depth
+        self.max_recurrence = max_recurrence
         self.auto_map = {"AutoModelForCausalLM": "Arthur-LAGACHERIE/RecLlama-code--modeling_recllama.RecLlamaForCausalLM", "AutoConfig":"Arthur-LAGACHERIE/RecLlama-code--modeling_recllama.RecLlamaConfig"}
 
         super().__init__(
@@ -89,18 +91,6 @@ class RecLlamaConfig(PretrainedConfig):
             **kwargs,
         )
 
-
-@dataclass
-class CausalLMOutputRecurrentLatents(ModelOutput):
-    loss: Optional[torch.Tensor] = None
-    log_ppl: Optional[torch.Tensor] = None
-    logits: Optional[torch.Tensor] = None
-    past_key_values: Optional[Cache] = None
-    latent_states: Optional[torch.Tensor] = None
-    hidden_states: Optional[torch.Tensor] = None
-    attention_maps: Optional[dict[int, torch.Tensor]] = None
-    stats: Optional[dict] = None
-
 
 
 class RecDynamicCache(DynamicCache):
@@ -146,7 +136,6 @@ class RecDynamicCache(DynamicCache):
         else:
             self.key_cache[layer_name] = torch.cat([self.key_cache[layer_name], key_states], dim=-2)
             self.value_cache[layer_name] = torch.cat([self.value_cache[layer_name], value_states], dim=-2)
-
         return self.key_cache[layer_name], self.value_cache[layer_name]
 
 
@@ -157,7 +146,42 @@ class RecLlamaForCausalLM(LlamaForCausalLM):
         self.prelude_layers = config.prelude_layers
         self.recurrent_layers = config.recurrent_layers
         self.coda_layers = config.coda_layers
+
+        for i in range(len(self.model.layers)):
+            self.model.layers[i].self_attn.k_proj.bias = nn.Parameter(torch.randn(1, self.model.layers[i].self_attn.k_proj.out_features))  #nn.Parameter(torch.full((1, self.model.layers[i].self_attn.k_proj.out_features), k_bias_value))
+            self.model.layers[i].self_attn.q_proj.bias = nn.Parameter(torch.randn(1, self.model.layers[i].self_attn.q_proj.out_features))
+
+
+    def get_recurrent_params(self):
+        recurrent_params = []
+
+        # Get indices of recurrent layers
+        recurrent_start = self.prelude_layers
+        recurrent_end = self.prelude_layers + self.recurrent_layers
+
+        # Extract parameters from recurrent layers
+        for layer_idx in range(recurrent_start, recurrent_end):
+            layer = self.model.layers[layer_idx]
+            for param_name, param in layer.named_parameters():
+                recurrent_params.append(param)
+
+        return sum(p.numel() for p in recurrent_params)
+
+    def get_param_count(self):
+        return sum(p.numel() for p in self.parameters())
 
+    def add_bias(self, q_bias_value=0.1, k_bias_value=0.1):
+        for i in range(len(self.model.layers)):
+            self.model.layers[i].self_attn.k_proj.bias = nn.Parameter(torch.randn(1, self.model.layers[i].self_attn.k_proj.out_features))  #nn.Parameter(torch.full((1, self.model.layers[i].self_attn.k_proj.out_features), k_bias_value))
+            self.model.layers[i].self_attn.q_proj.bias = nn.Parameter(torch.randn(1, self.model.layers[i].self_attn.q_proj.out_features))
+
+    @staticmethod
+    def add_bias_to_model(model, q_bias_value=0.1, k_bias_value=0.1):
+        for i in range(len(model.model.layers)):
+            model.model.layers[i].self_attn.k_proj.bias = nn.Parameter(torch.zeros(1, model.model.layers[i].self_attn.k_proj.out_features))
+            model.model.layers[i].self_attn.q_proj.bias = nn.Parameter(torch.zeros(1, model.model.layers[i].self_attn.q_proj.out_features))
+        return model
+
     @classmethod
     def from_llama_model(
         cls,
@@ -167,6 +191,7 @@ class RecLlamaForCausalLM(LlamaForCausalLM):
         coda_layers: int,
         mean_recurrence: int = 4,
         max_backprop_depth: int = 6,
+        max_recurrence: int = 8,
     ) -> "RecLlamaForCausalLM":
         """
         Convert a regular LlamaForCausalLM model to a RecLlamaForCausalLM model.
@@ -197,13 +222,14 @@ class RecLlamaForCausalLM(LlamaForCausalLM):
         config.coda_layers = coda_layers
         config.mean_recurrence = mean_recurrence
         config.max_backprop_depth = max_backprop_depth
+        config.max_recurrence = max_recurrence
 
         rec_model = cls(config)
         rec_model.model.embed_tokens = llama_model.model.embed_tokens
        rec_model.model.norm = llama_model.model.norm
         rec_model.model.layers = llama_model.model.layers
         rec_model.lm_head = llama_model.lm_head
-
+        rec_model = RecLlamaForCausalLM.add_bias_to_model(rec_model)
         return rec_model
 
 
@@ -224,7 +250,7 @@ class RecLlamaForCausalLM(LlamaForCausalLM):
         num_steps: int = None,
         **kwargs: Unpack[KwargsForCausalLM],
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -243,7 +269,7 @@ class RecLlamaForCausalLM(LlamaForCausalLM):
             cache_position = torch.arange(
                 past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
             )
-
+
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
 
@@ -275,15 +301,15 @@ class RecLlamaForCausalLM(LlamaForCausalLM):
 
         # recurrent block
         inputs_embeds = self.iterate_forward(
-            inputs_embeds,
-            causal_mask,
-            position_ids,
-            past_key_values,
-            output_attentions,
-            use_cache,
-            cache_position,
-            position_embeddings,
-            num_steps
+            inputs_embeds=inputs_embeds,
+            attention_mask=causal_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_values,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            num_steps=num_steps
         )
 
         # coda blocks
@@ -402,7 +428,7 @@ class RecLlamaForCausalLM(LlamaForCausalLM):
             mu = math.log(t) - (sigma**2 / 2)
             rate = torch.zeros((1,), dtype=torch.float).log_normal_(mean=mu, std=sigma)
             n = torch.poisson(rate) + 1  # Corrected Poisson sampling
-            n = torch.clamp(n, min=0)  # Ensure non-negative
+            n = torch.clamp(n, min=0, max=self.config.max_recurrence)  # Ensure non-negative
             k = torch.clamp(n, max=self.config.max_backprop_depth)  # Limit k properly
         else:
             n = torch.tensor(self.config.mean_recurrence, dtype=torch.long)
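For context on the last hunk, here is a standalone sketch (not taken from the repository) of how the randomized recurrence-depth sampling behaves with the new max_recurrence cap. Only the clamp line mirrors the diff; the value of sigma and the use of mean_recurrence for t are assumptions, since the diff does not show how they are set.

import math
import torch

def sample_recurrence(mean_recurrence=12, max_recurrence=18, max_backprop_depth=8, sigma=0.5):
    # Log-normal rate centered so the expected number of steps is roughly mean_recurrence.
    t = mean_recurrence
    mu = math.log(t) - (sigma**2 / 2)
    rate = torch.zeros((1,), dtype=torch.float).log_normal_(mean=mu, std=sigma)
    n = torch.poisson(rate) + 1                      # total recurrence steps, at least 1
    n = torch.clamp(n, min=0, max=max_recurrence)    # upper bound introduced by this commit
    k = torch.clamp(n, max=max_backprop_depth)       # n capped at max_backprop_depth
    return int(n.item()), int(k.item())

n_steps, k_steps = sample_recurrence()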
 
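A minimal usage sketch of the updated conversion API, not taken from the repository: the diff only shows the trailing parameters of from_llama_model (coda_layers, mean_recurrence, max_backprop_depth, and the new max_recurrence), so the leading argument and keyword names below, the import path, and the checkpoint name are assumptions.

import torch
from transformers import LlamaForCausalLM
from modeling_recllama import RecLlamaForCausalLM  # assumes the file is importable locally

# Placeholder checkpoint; any LlamaForCausalLM should work in principle.
llama = LlamaForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", torch_dtype=torch.bfloat16)

# Keyword names before coda_layers are guesses based on the config fields.
rec_model = RecLlamaForCausalLM.from_llama_model(
    llama,
    prelude_layers=2,
    recurrent_layers=4,
    coda_layers=2,
    mean_recurrence=12,
    max_backprop_depth=8,
    max_recurrence=18,   # new argument added in this commit
)

# Helpers added in this commit: total parameter count and the size of the shared recurrent block.
print(rec_model.get_param_count())
print(rec_model.get_recurrent_params())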