from transformers.models.llama.modeling_llama import LlamaModel, LlamaConfig, LlamaForCausalLM, KwargsForCausalLM
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union, Any, Dict
from transformers.cache_utils import Cache, DynamicCache
from transformers.utils import ModelOutput
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.processing_utils import Unpack
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers import PretrainedConfig
import math


class RecLlamaConfig(PretrainedConfig):
    model_type = "rec_llama"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        mlp_bias=False,
        head_dim=None,
        prelude_layers: int = 2,
        recurrent_layers: int = 2,
        coda_layers: int = 2,
        mean_recurrence: int = 12,
        max_backprop_depth: int = 8,
        max_recurrence: int = 18,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # For backward compatibility: default to one key/value head per attention head.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.mlp_bias = mlp_bias
        self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads

        # Validate the rotary position embedding parameters.
        # BC: if there is a 'type' field, copy it to 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]

        self.prelude_layers = prelude_layers
        self.recurrent_layers = recurrent_layers
        self.coda_layers = coda_layers
        self.mean_recurrence = mean_recurrence
        self.max_backprop_depth = max_backprop_depth
        self.max_recurrence = max_recurrence
        self.auto_map = {
            "AutoModelForCausalLM": "Arthur-LAGACHERIE/RecLlama-code--modeling_recllama.RecLlamaForCausalLM",
            "AutoConfig": "Arthur-LAGACHERIE/RecLlama-code--modeling_recllama.RecLlamaConfig",
        }

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


class RecDynamicCache(DynamicCache):
    """Dynamic KV cache that keeps a separate entry for every pass through a recurrent layer."""

    def __init__(self, rec_layers: List[int]) -> None:
        super().__init__()
        self._seen_tokens = 0  # Used in `generate` to keep a tally of how many tokens the cache has seen.
        self.rec_layers = rec_layers
        self.key_cache: Dict[str, torch.Tensor] = {}
        self.value_cache: Dict[str, torch.Tensor] = {}
        self.rec_counters = {layer: 0 for layer in rec_layers}

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if layer_idx not in self.rec_layers:
            # Non-recurrent layer: a single "layer-{idx}" entry, extended along the sequence dimension.
            layer_name = f"layer-{layer_idx}"
            if layer_idx == 0:
                self._seen_tokens += key_states.shape[-2]

            # Update the cache
            if key_states is not None:
                if layer_name not in self.key_cache:
                    self.key_cache[layer_name] = key_states
                    self.value_cache[layer_name] = value_states
                else:
                    self.key_cache[layer_name] = torch.cat([self.key_cache[layer_name], key_states], dim=-2)
                    self.value_cache[layer_name] = torch.cat([self.value_cache[layer_name], value_states], dim=-2)
        else:
            # Recurrent layer: each pass through the layer gets its own "rec-{idx}-{counter}" entry.
            layer_name = f"rec-{layer_idx}-{self.rec_counters[layer_idx]}"
            self.rec_counters[layer_idx] += 1

            # Update the cache for recurrent layers
            if layer_name not in self.key_cache:
                self.key_cache[layer_name] = key_states
                self.value_cache[layer_name] = value_states
            else:
                self.key_cache[layer_name] = torch.cat([self.key_cache[layer_name], key_states], dim=-2)
                self.value_cache[layer_name] = torch.cat([self.value_cache[layer_name], value_states], dim=-2)

        return self.key_cache[layer_name], self.value_cache[layer_name]


class RecLlamaForCausalLM(LlamaForCausalLM):
    config_class = RecLlamaConfig

    def __init__(self, config: RecLlamaConfig, num_steps=None):
        super().__init__(config)
        self.prelude_layers = config.prelude_layers
        self.recurrent_layers = config.recurrent_layers
        self.coda_layers = config.coda_layers
        self.num_steps = num_steps

        # Add learnable biases to the key/query projections of every layer.
        for i in range(len(self.model.layers)):
            self.model.layers[i].self_attn.k_proj.bias = nn.Parameter(
                torch.randn(1, self.model.layers[i].self_attn.k_proj.out_features)
            )
            self.model.layers[i].self_attn.q_proj.bias = nn.Parameter(
                torch.randn(1, self.model.layers[i].self_attn.q_proj.out_features)
            )

    def get_recurrent_params(self):
        """Return the number of parameters in the recurrent block."""
        recurrent_params = []

        # Get indices of recurrent layers
        recurrent_start = self.prelude_layers
        recurrent_end = self.prelude_layers + self.recurrent_layers

        # Extract parameters from recurrent layers
        for layer_idx in range(recurrent_start, recurrent_end):
            layer = self.model.layers[layer_idx]
            for param_name, param in layer.named_parameters():
                recurrent_params.append(param)

        return sum(p.numel() for p in recurrent_params)

    def get_param_count(self):
        return sum(p.numel() for p in self.parameters())

    def add_bias(self, q_bias_value=0.1, k_bias_value=0.1):
        # NOTE: q_bias_value / k_bias_value are currently unused; biases are re-initialized
        # randomly, matching the initialization done in __init__.
        for i in range(len(self.model.layers)):
            self.model.layers[i].self_attn.k_proj.bias = nn.Parameter(
                torch.randn(1, self.model.layers[i].self_attn.k_proj.out_features)
            )
            self.model.layers[i].self_attn.q_proj.bias = nn.Parameter(
                torch.randn(1, self.model.layers[i].self_attn.q_proj.out_features)
            )

    @staticmethod
    def add_bias_to_model(model, q_bias_value=0.1, k_bias_value=0.1):
        # NOTE: q_bias_value / k_bias_value are currently unused; biases are initialized to zero.
        for i in range(len(model.model.layers)):
            model.model.layers[i].self_attn.k_proj.bias = nn.Parameter(
                torch.zeros(1, model.model.layers[i].self_attn.k_proj.out_features)
            )
            model.model.layers[i].self_attn.q_proj.bias = nn.Parameter(
                torch.zeros(1, model.model.layers[i].self_attn.q_proj.out_features)
            )
        return model

    @classmethod
    def from_llama_model(
        cls,
        llama_model: LlamaForCausalLM,
        prelude_layers: int,
        recurrent_layers: int,
        coda_layers: int,
        mean_recurrence: int = 4,
        max_backprop_depth: int = 6,
        max_recurrence: int = 8,
    ) -> "RecLlamaForCausalLM":
        """
        Convert a regular LlamaForCausalLM model to a RecLlamaForCausalLM model.

        Args:
            llama_model: The source LlamaForCausalLM model.
            prelude_layers: Number of non-recurrent layers at the start.
            recurrent_layers: Number of recurrent layers in the middle.
            coda_layers: Number of non-recurrent layers at the end.
            mean_recurrence: Average number of recurrent iterations (default: 4).
            max_backprop_depth: Maximum number of iterations to backpropagate through (default: 6).
            max_recurrence: Maximum number of recurrent iterations (default: 8).

        Returns:
            A RecLlamaForCausalLM model with weights copied from the source model.
            See the usage sketch at the bottom of this file for an end-to-end example.
        """
        # Validate layer counts
        total_layers = len(llama_model.model.layers)
        if prelude_layers + recurrent_layers + coda_layers != total_layers:
            raise ValueError(
                f"Sum of layers ({prelude_layers + recurrent_layers + coda_layers}) "
                f"must equal total number of model layers ({total_layers})"
            )

        # Create new config based on the original model's config
        config = RecLlamaConfig(**llama_model.config.to_dict())
        config.prelude_layers = prelude_layers
        config.recurrent_layers = recurrent_layers
        config.coda_layers = coda_layers
        config.mean_recurrence = mean_recurrence
        config.max_backprop_depth = max_backprop_depth
        config.max_recurrence = max_recurrence

        rec_model = cls(config)
        rec_model.model.embed_tokens = llama_model.model.embed_tokens
        rec_model.model.norm = llama_model.model.norm
        rec_model.model.layers = llama_model.model.layers
        rec_model.lm_head = llama_model.lm_head
        rec_model = RecLlamaForCausalLM.add_bias_to_model(rec_model)
        return rec_model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        num_steps: int = None,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            recurrent_layers = list(range(self.prelude_layers, self.prelude_layers + self.recurrent_layers))
            past_key_values = RecDynamicCache(recurrent_layers)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self.model._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )
        position_embeddings = self.model.rotary_emb(inputs_embeds, position_ids)

        # Run the non-recurrent blocks (prelude).
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        for block_idx, block in enumerate(self.model.layers[: self.prelude_layers]):
            layer_outputs = block(
                inputs_embeds,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )
            inputs_embeds = layer_outputs[0]
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # Recurrent block.
        inputs_embeds = self.iterate_forward(
            inputs_embeds=inputs_embeds,
            attention_mask=causal_mask,
            position_ids=position_ids,
            past_key_value=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            num_steps=num_steps,
        )

        # Coda blocks.
        coda_start = self.prelude_layers + self.recurrent_layers
        coda_end = coda_start + self.coda_layers
        for block_idx, block in enumerate(self.model.layers[coda_start:coda_end]):
            layer_outputs = block(
                inputs_embeds,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
            )
            inputs_embeds = layer_outputs[0]
            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        inputs_embeds = self.model.norm(inputs_embeds)
        if output_hidden_states:
            all_hidden_states += (inputs_embeds,)

        outputs = BaseModelOutputWithPast(
            last_hidden_state=inputs_embeds,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(inputs_embeds[:, slice_indices, :])

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.model.config.vocab_size, **kwargs)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @torch._dynamo.disable(recursive=False)  # type: ignore
    def iterate_forward(
        self,
        inputs_embeds,
        attention_mask,
        position_ids,
        past_key_value,
        output_attentions,
        use_cache,
        cache_position,
        position_embeddings,
        num_steps=None,
    ):
        if num_steps is None and self.num_steps is None:
            num_steps_no_grad, num_steps_with_grad = self.randomized_iteration_sampler()  # type: ignore
        elif hasattr(num_steps, "__len__") and len(num_steps) > 1:
            num_steps_no_grad, num_steps_with_grad = num_steps
        elif self.num_steps is not None:
            num_steps_no_grad, num_steps_with_grad = self.num_steps, self.num_steps
        else:
            num_steps_no_grad, num_steps_with_grad = num_steps, torch.tensor(0)

        with torch.no_grad():
            # Ultra annoying in DDP due to
            # https://discuss.pytorch.org/t/does-distributeddataparallel-work-with-torch-no-grad-and-find-unused-parameters-false/122594
            # For now this runs with find_unused_parameters=True enabled, even though the graph structure is
            # (technically) clear and all parameters are always used.
            for step in range(num_steps_no_grad):
                for block_idx, block in enumerate(
                    self.model.layers[self.prelude_layers : self.prelude_layers + self.recurrent_layers]
                ):
                    layer_output = block(
                        inputs_embeds,
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                        past_key_value=past_key_value,
                        output_attentions=output_attentions,
                        use_cache=use_cache,
                        cache_position=cache_position,
                        position_embeddings=position_embeddings,
                    )
                    inputs_embeds = layer_output[0]

        # The final iterations run with gradients enabled so that backpropagation reaches the recurrent block.
        for step in range(num_steps_with_grad):
            for block_idx, block in enumerate(
                self.model.layers[self.prelude_layers : self.prelude_layers + self.recurrent_layers]
            ):
                layer_output = block(
                    inputs_embeds,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )
                inputs_embeds = layer_output[0]

        return inputs_embeds

    @torch._dynamo.disable(recursive=False)  # type: ignore
    def randomized_iteration_sampler(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Outputs are long tensors so that they can be passed through compiled functions."""
        t = max(self.config.mean_recurrence, 0)
        if self.training:
            # Sample the total recurrence from a log-normal Poisson distribution centered on mean_recurrence.
            sigma = 0.5
            mu = math.log(t) - (sigma**2 / 2)
            rate = torch.zeros((1,), dtype=torch.float).log_normal_(mean=mu, std=sigma)
            n = torch.poisson(rate) + 1  # Shift so that at least one iteration is performed.
            n = torch.clamp(n, min=0, max=self.config.max_recurrence)  # Cap the total recurrence.
            k = torch.clamp(n, max=self.config.max_backprop_depth)  # Cap the number of backprop steps.
        else:
            n = torch.tensor(self.config.mean_recurrence, dtype=torch.long)
            k = torch.tensor(0, dtype=torch.long)
        return n.to(dtype=torch.long), k.to(dtype=torch.long)

    @torch.no_grad()
    def generate(self, *args, **kwargs):
        return super().generate(*args, **kwargs)
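

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the model definition). It builds a
# tiny randomly initialised Llama and wraps it with `from_llama_model`; the
# layer split and all sizes below are assumptions chosen for a quick smoke
# test, not tuned values. With a real checkpoint you would load it via
# `LlamaForCausalLM.from_pretrained(...)` and pick a split that sums to its
# layer count.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    base_config = LlamaConfig(
        vocab_size=128,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=6,
        num_attention_heads=4,
        num_key_value_heads=4,
    )
    base_model = LlamaForCausalLM(base_config)

    # 2 prelude + 2 recurrent + 2 coda layers must sum to the 6 layers above.
    model = RecLlamaForCausalLM.from_llama_model(
        base_model,
        prelude_layers=2,
        recurrent_layers=2,
        coda_layers=2,
        mean_recurrence=4,
        max_backprop_depth=2,
        max_recurrence=6,
    )
    model.eval()

    print(f"total params:     {model.get_param_count():,}")
    print(f"recurrent params: {model.get_recurrent_params():,}")

    # Forward pass with an explicit recurrence schedule: 3 iterations of the
    # recurrent block without gradients, 0 iterations with gradients.
    input_ids = torch.randint(0, base_config.vocab_size, (1, 8))
    with torch.no_grad():
        out = model(input_ids, use_cache=False, num_steps=(3, 0))
    print("logits shape:", tuple(out.logits.shape))

    # With the cache enabled, non-recurrent layers get one "layer-{idx}" entry
    # while every pass through a recurrent layer gets its own "rec-{idx}-{n}"
    # entry in RecDynamicCache.
    with torch.no_grad():
        out = model(input_ids, use_cache=True, num_steps=(2, 0))
    print("cache entries:", list(out.past_key_values.key_cache.keys()))

    # During training the recurrence schedule is sampled instead:
    # n = total iterations, k = iterations to backpropagate through.
    model.train()
    n, k = model.randomized_iteration_sampler()
    print(f"sampled recurrence: n={int(n)}, k={int(k)}")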