Update modeling_recllama.py
modeling_recllama.py CHANGED (+52 -26)
@@ -45,6 +45,7 @@ class RecLlamaConfig(PretrainedConfig):
         coda_layers:int = 2,
         mean_recurrence:int = 12,
         max_backprop_depth:int = 8,
+        max_recurrence:int = 18,
         **kwargs
     ):
         self.vocab_size = vocab_size
@@ -79,6 +80,7 @@ class RecLlamaConfig(PretrainedConfig):
         self.coda_layers = coda_layers
         self.mean_recurrence = mean_recurrence
         self.max_backprop_depth = max_backprop_depth
+        self.max_recurrence = max_recurrence
         self.auto_map = {"AutoModelForCausalLM": "Arthur-LAGACHERIE/RecLlama-code--modeling_recllama.RecLlamaForCausalLM", "AutoConfig":"Arthur-LAGACHERIE/RecLlama-code--modeling_recllama.RecLlamaConfig"}

         super().__init__(
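The new `max_recurrence` field gives the config a hard ceiling on sampled recurrence depth, alongside the existing `mean_recurrence` and `max_backprop_depth`. A minimal construction sketch; constructor arguments not visible in this diff are assumed to keep their defaults:

```python
# Sketch: constructing the updated config (only fields shown in this diff;
# all other constructor arguments are assumed to keep their defaults).
from modeling_recllama import RecLlamaConfig

config = RecLlamaConfig(
    coda_layers=2,
    mean_recurrence=12,
    max_backprop_depth=8,
    max_recurrence=18,  # new in this commit
)
assert config.max_recurrence == 18
```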
@@ -89,18 +91,6 @@ class RecLlamaConfig(PretrainedConfig):
             **kwargs,
         )

-
-@dataclass
-class CausalLMOutputRecurrentLatents(ModelOutput):
-    loss: Optional[torch.Tensor] = None
-    log_ppl: Optional[torch.Tensor] = None
-    logits: Optional[torch.Tensor] = None
-    past_key_values: Optional[Cache] = None
-    latent_states: Optional[torch.Tensor] = None
-    hidden_states: Optional[torch.Tensor] = None
-    attention_maps: Optional[dict[int, torch.Tensor]] = None
-    stats: Optional[dict] = None
-


 class RecDynamicCache(DynamicCache):
@@ -146,7 +136,6 @@ class RecDynamicCache(DynamicCache):
         else:
             self.key_cache[layer_name] = torch.cat([self.key_cache[layer_name], key_states], dim=-2)
             self.value_cache[layer_name] = torch.cat([self.value_cache[layer_name], value_states], dim=-2)
-
         return self.key_cache[layer_name], self.value_cache[layer_name]

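For context on the `RecDynamicCache` hunk above: the cache is keyed by a string `layer_name`, and each update concatenates incoming key/value states along the sequence axis (`dim=-2`). A standalone sketch of that semantics; the layer name and tensor shapes are illustrative assumptions, not taken from the file:

```python
# Illustrative only: shapes are (batch, num_heads, seq_len, head_dim) and
# "prelude_0" is a made-up layer name.
import torch

key_cache = {"prelude_0": torch.randn(1, 8, 5, 64)}  # 5 cached positions
new_keys = torch.randn(1, 8, 1, 64)                  # 1 incoming position

key_cache["prelude_0"] = torch.cat([key_cache["prelude_0"], new_keys], dim=-2)
print(key_cache["prelude_0"].shape)  # torch.Size([1, 8, 6, 64])
```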
@@ -157,7 +146,42 @@ class RecLlamaForCausalLM(LlamaForCausalLM):
         self.prelude_layers = config.prelude_layers
         self.recurrent_layers = config.recurrent_layers
         self.coda_layers = config.coda_layers
+
+        for i in range(len(self.model.layers)):
+            self.model.layers[i].self_attn.k_proj.bias = nn.Parameter(torch.randn(1, self.model.layers[i].self_attn.k_proj.out_features)) #nn.Parameter(torch.full((1, self.model.layers[i].self_attn.k_proj.out_features), k_bias_value))
+            self.model.layers[i].self_attn.q_proj.bias = nn.Parameter(torch.randn(1, self.model.layers[i].self_attn.q_proj.out_features))
+
+
+    def get_recurrent_params(self):
+        recurrent_params = []
+
+        # Get indices of recurrent layers
+        recurrent_start = self.prelude_layers
+        recurrent_end = self.prelude_layers + self.recurrent_layers
+
+        # Extract parameters from recurrent layers
+        for layer_idx in range(recurrent_start, recurrent_end):
+            layer = self.model.layers[layer_idx]
+            for param_name, param in layer.named_parameters():
+                recurrent_params.append(param)
+
+        return sum(p.numel() for p in recurrent_params)
+
+    def get_param_count(self):
+        return sum(p.numel() for p in self.parameters())

+    def add_bias(self, q_bias_value=0.1, k_bias_value=0.1):
+        for i in range(len(self.model.layers)):
+            self.model.layers[i].self_attn.k_proj.bias = nn.Parameter(torch.randn(1, self.model.layers[i].self_attn.k_proj.out_features)) #nn.Parameter(torch.full((1, self.model.layers[i].self_attn.k_proj.out_features), k_bias_value))
+            self.model.layers[i].self_attn.q_proj.bias = nn.Parameter(torch.randn(1, self.model.layers[i].self_attn.q_proj.out_features))
+
+    @staticmethod
+    def add_bias_to_model(model, q_bias_value=0.1, k_bias_value=0.1):
+        for i in range(len(model.model.layers)):
+            model.model.layers[i].self_attn.k_proj.bias = nn.Parameter(torch.zeros(1, model.model.layers[i].self_attn.k_proj.out_features))
+            model.model.layers[i].self_attn.q_proj.bias = nn.Parameter(torch.zeros(1, model.model.layers[i].self_attn.q_proj.out_features))
+        return model
+
     @classmethod
     def from_llama_model(
         cls,
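A hedged sketch of how the helpers added in this hunk might be used, assuming `model` is an instantiated `RecLlamaForCausalLM`. Two things worth noting from the diff itself: `get_recurrent_params` returns a parameter count (it sums `numel` over the collected tensors), and `add_bias` as committed draws fresh biases from `torch.randn`, so its `q_bias_value`/`k_bias_value` arguments are currently unused (a `torch.full` variant survives in a comment):

```python
# Sketch, assuming `model` is a RecLlamaForCausalLM instance.
total = model.get_param_count()           # numel over all parameters
recurrent = model.get_recurrent_params()  # numel over the recurrent block only
print(f"recurrent share of parameters: {recurrent / total:.1%}")

model.add_bias()  # re-creates q/k biases; randn-initialised, value args unused
print(model.model.layers[0].self_attn.q_proj.bias.shape)  # torch.Size([1, out_features])
```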
@@ -167,6 +191,7 @@ class RecLlamaForCausalLM(LlamaForCausalLM):
         coda_layers: int,
         mean_recurrence: int = 4,
         max_backprop_depth: int = 6,
+        max_recurrence: int = 8,
     ) -> "RecLlamaForCausalLM":
         """
         Convert a regular LlamaForCausalLM model to a RecLlamaForCausalLM model.
@@ -197,13 +222,14 @@ class RecLlamaForCausalLM(LlamaForCausalLM):
         config.coda_layers = coda_layers
         config.mean_recurrence = mean_recurrence
         config.max_backprop_depth = max_backprop_depth
+        config.max_recurrence = max_recurrence

         rec_model = cls(config)
         rec_model.model.embed_tokens = llama_model.model.embed_tokens
         rec_model.model.norm = llama_model.model.norm
         rec_model.model.layers = llama_model.model.layers
         rec_model.lm_head = llama_model.lm_head
-
+        rec_model = RecLlamaForCausalLM.add_bias_to_model(rec_model)
         return rec_model

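Usage sketch for the updated conversion path, which now threads `max_recurrence` into the config and zero-initialises the q/k projection biases via `add_bias_to_model`. The checkpoint name is illustrative, and the `prelude_layers`/`recurrent_layers` keywords are assumptions inferred from the config fields, since they sit outside this hunk:

```python
# Sketch: converting a stock Llama checkpoint (checkpoint name illustrative).
from transformers import LlamaForCausalLM

llama = LlamaForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
rec = RecLlamaForCausalLM.from_llama_model(
    llama,
    prelude_layers=2,      # assumed keyword, not shown in this hunk
    recurrent_layers=4,    # assumed keyword, not shown in this hunk
    coda_layers=2,
    mean_recurrence=4,
    max_backprop_depth=6,
    max_recurrence=8,      # new argument added in this commit
)
# from_llama_model now ends by calling add_bias_to_model, so the converted
# model starts with zero-initialised q/k projection biases.
```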
@@ -224,7 +250,7 @@ class RecLlamaForCausalLM(LlamaForCausalLM):
         num_steps: int = None,
         **kwargs: Unpack[KwargsForCausalLM],
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -243,7 +269,7 @@ class RecLlamaForCausalLM(LlamaForCausalLM):
         cache_position = torch.arange(
             past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
         )
-
+
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

@@ -275,15 +301,15 @@ class RecLlamaForCausalLM(LlamaForCausalLM):

         # recurrent block
         inputs_embeds = self.iterate_forward(
-            inputs_embeds,
-            causal_mask,
-            position_ids,
-            past_key_values,
-            output_attentions,
-            use_cache,
-            cache_position,
-            position_embeddings,
-            num_steps
+            inputs_embeds=inputs_embeds,
+            attention_mask=causal_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_values,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            position_embeddings=position_embeddings,
+            num_steps=num_steps
         )

         # coda blocks
@@ -402,7 +428,7 @@ class RecLlamaForCausalLM(LlamaForCausalLM):
             mu = math.log(t) - (sigma**2 / 2)
             rate = torch.zeros((1,), dtype=torch.float).log_normal_(mean=mu, std=sigma)
             n = torch.poisson(rate) + 1  # Corrected Poisson sampling
-            n = torch.clamp(n, min=0)  # Ensure non-negative
+            n = torch.clamp(n, min=0, max=self.config.max_recurrence)  # Ensure non-negative and cap at max_recurrence
             k = torch.clamp(n, max=self.config.max_backprop_depth)  # Limit k properly
         else:
             n = torch.tensor(self.config.mean_recurrence, dtype=torch.long)
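The sampling path above draws a rate from a log-normal whose mean is `t` (since `mu = log(t) - sigma**2/2`), Poisson-samples the iteration count `n`, and truncates the backprop depth `k`; the change clamps `n` at the new `max_recurrence` ceiling. A standalone sketch with assumed values for `sigma` and `t` (in the model, `t` presumably derives from `config.mean_recurrence`, which the non-sampling branch uses directly):

```python
# Standalone sketch of the recurrence-depth sampling; sigma and t are assumed.
import math
import torch

sigma, t = 0.5, 12.0
max_recurrence, max_backprop_depth = 18, 8

mu = math.log(t) - (sigma**2 / 2)  # chosen so the log-normal rate has mean t
rate = torch.zeros((1,), dtype=torch.float).log_normal_(mean=mu, std=sigma)
n = torch.poisson(rate) + 1                    # number of recurrent iterations
n = torch.clamp(n, min=0, max=max_recurrence)  # new upper bound from this commit
k = torch.clamp(n, max=max_backprop_depth)     # iterations that receive gradients
print(int(n.item()), int(k.item()))
```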
|