|
from transformers.models.llama.modeling_llama import LlamaModel, LlamaConfig, LlamaForCausalLM, KwargsForCausalLM |
|
from dataclasses import dataclass |
|
from typing import Callable, List, Optional, Tuple, Union, Any, Dict |
|
from transformers.cache_utils import Cache, DynamicCache |
|
from transformers.utils import ModelOutput |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F
|
from transformers.processing_utils import Unpack |
|
from transformers.modeling_outputs import ( |
|
BaseModelOutputWithPast, |
|
CausalLMOutputWithPast, |
|
) |
|
from transformers import PretrainedConfig |
|
import math |
|
|
|
class RecLlamaConfig(PretrainedConfig): |
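    """Configuration for RecLlama: a Llama variant whose middle `recurrent_layers` decoder
    blocks are applied recurrently between `prelude_layers` and `coda_layers` blocks.
    `mean_recurrence`, `max_recurrence` and `max_backprop_depth` control how the recurrence
    is unrolled and how many of the unrolled passes receive gradients."""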
|
model_type = "rec_llama" |
|
def __init__( |
|
self, |
|
vocab_size=32000, |
|
hidden_size=4096, |
|
intermediate_size=11008, |
|
num_hidden_layers=32, |
|
num_attention_heads=32, |
|
num_key_value_heads=None, |
|
hidden_act="silu", |
|
max_position_embeddings=2048, |
|
initializer_range=0.02, |
|
rms_norm_eps=1e-6, |
|
use_cache=True, |
|
pad_token_id=None, |
|
bos_token_id=1, |
|
eos_token_id=2, |
|
pretraining_tp=1, |
|
tie_word_embeddings=False, |
|
rope_theta=10000.0, |
|
rope_scaling=None, |
|
attention_bias=False, |
|
attention_dropout=0.0, |
|
mlp_bias=False, |
|
head_dim=None, |
|
        prelude_layers: int = 2,

        recurrent_layers: int = 2,

        coda_layers: int = 2,

        mean_recurrence: int = 12,

        max_backprop_depth: int = 8,

        max_recurrence: int = 18,
|
**kwargs |
|
): |
|
self.vocab_size = vocab_size |
|
self.max_position_embeddings = max_position_embeddings |
|
self.hidden_size = hidden_size |
|
self.intermediate_size = intermediate_size |
|
self.num_hidden_layers = num_hidden_layers |
|
self.num_attention_heads = num_attention_heads |
|
|
|
|
|
if num_key_value_heads is None: |
|
num_key_value_heads = num_attention_heads |
|
|
|
self.num_key_value_heads = num_key_value_heads |
|
self.hidden_act = hidden_act |
|
self.initializer_range = initializer_range |
|
self.rms_norm_eps = rms_norm_eps |
|
self.pretraining_tp = pretraining_tp |
|
self.use_cache = use_cache |
|
self.rope_theta = rope_theta |
|
self.rope_scaling = rope_scaling |
|
self.attention_bias = attention_bias |
|
self.attention_dropout = attention_dropout |
|
self.mlp_bias = mlp_bias |
|
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads |
|
|
|
|
|
if self.rope_scaling is not None and "type" in self.rope_scaling: |
|
self.rope_scaling["rope_type"] = self.rope_scaling["type"] |
|
self.prelude_layers = prelude_layers |
|
self.recurrent_layers = recurrent_layers |
|
self.coda_layers = coda_layers |
|
self.mean_recurrence = mean_recurrence |
|
self.max_backprop_depth = max_backprop_depth |
|
self.max_recurrence = max_recurrence |
|
self.auto_map = {"AutoModelForCausalLM": "Arthur-LAGACHERIE/RecLlama-code--modeling_recllama.RecLlamaForCausalLM", "AutoConfig":"Arthur-LAGACHERIE/RecLlama-code--modeling_recllama.RecLlamaConfig"} |
|
|
|
super().__init__( |
|
pad_token_id=pad_token_id, |
|
bos_token_id=bos_token_id, |
|
eos_token_id=eos_token_id, |
|
tie_word_embeddings=tie_word_embeddings, |
|
**kwargs, |
|
) |
|
|
|
|
|
|
|
class RecDynamicCache(DynamicCache): |
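    """DynamicCache variant with string-keyed entries: regular layers keep one persistent
    `layer-{idx}` entry, while every unrolled pass over a recurrent layer gets its own
    `rec-{idx}-{counter}` entry."""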
|
def __init__(self, rec_layers: List[int]) -> None: |
|
super().__init__() |
|
self._seen_tokens = 0 |
|
self.rec_layers = rec_layers |
|
self.key_cache: Dict[str, torch.Tensor] = {} |
|
self.value_cache: Dict[str, torch.Tensor] = {} |
|
self.rec_counters = {layer: 0 for layer in rec_layers} |
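
    def get_seq_length(self, layer_idx: int = 0) -> int:
        # The parent class looks the cache up by integer layer position, which no longer
        # applies to the string-keyed dicts used here, so report the running token count
        # tracked in update() instead.
        return self._seen_tokens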
|
|
|
def update( |
|
self, |
|
key_states: torch.Tensor, |
|
value_states: torch.Tensor, |
|
layer_idx: int, |
|
cache_kwargs: Optional[Dict[str, Any]] = None, |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
|
if layer_idx not in self.rec_layers: |
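            # Regular (non-recurrent) layer: keep one persistent cache entry per layer.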
|
|
|
layer_name = f"layer-{layer_idx}" |
|
if layer_idx == 0: |
|
self._seen_tokens += key_states.shape[-2] |
|
|
|
|
|
if key_states is not None: |
|
if layer_name not in self.key_cache: |
|
self.key_cache[layer_name] = key_states |
|
self.value_cache[layer_name] = value_states |
|
else: |
|
self.key_cache[layer_name] = torch.cat([self.key_cache[layer_name], key_states], dim=-2) |
|
self.value_cache[layer_name] = torch.cat([self.value_cache[layer_name], value_states], dim=-2) |
|
else: |
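            # Recurrent layer: every unrolled pass gets its own cache entry, keyed by the
            # layer index and a per-layer pass counter.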
|
|
|
layer_name = f"rec-{layer_idx}-{self.rec_counters[layer_idx]}" |
|
self.rec_counters[layer_idx] += 1 |
|
|
|
|
|
if layer_name not in self.key_cache: |
|
self.key_cache[layer_name] = key_states |
|
self.value_cache[layer_name] = value_states |
|
else: |
|
self.key_cache[layer_name] = torch.cat([self.key_cache[layer_name], key_states], dim=-2) |
|
self.value_cache[layer_name] = torch.cat([self.value_cache[layer_name], value_states], dim=-2) |
|
return self.key_cache[layer_name], self.value_cache[layer_name] |
|
|
|
|
|
class RecLlamaForCausalLM(LlamaForCausalLM): |
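    """Llama causal LM whose decoder is split into `prelude_layers` blocks run once,
    `recurrent_layers` blocks unrolled a (possibly sampled) number of times, and
    `coda_layers` blocks run once at the end."""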
|
config_class = RecLlamaConfig |
|
def __init__(self, config: RecLlamaConfig, num_steps=None): |
|
super().__init__(config) |
|
self.prelude_layers = config.prelude_layers |
|
self.recurrent_layers = config.recurrent_layers |
|
self.coda_layers = config.coda_layers |
|
self.num_steps = num_steps |
|
|
|
        # Attach learnable biases to the q/k projections; Llama's attention projections are
        # bias-free by default. The (1, out_features) shape broadcasts against the projection output.
        for i in range(len(self.model.layers)):
|
self.model.layers[i].self_attn.k_proj.bias = nn.Parameter(torch.randn(1, self.model.layers[i].self_attn.k_proj.out_features)) |
|
self.model.layers[i].self_attn.q_proj.bias = nn.Parameter(torch.randn(1, self.model.layers[i].self_attn.q_proj.out_features)) |
|
|
|
|
|
def get_recurrent_params(self): |
|
recurrent_params = [] |
|
|
|
|
|
recurrent_start = self.prelude_layers |
|
recurrent_end = self.prelude_layers + self.recurrent_layers |
|
|
|
|
|
for layer_idx in range(recurrent_start, recurrent_end): |
|
layer = self.model.layers[layer_idx] |
|
            for param in layer.parameters():

                recurrent_params.append(param)
|
|
|
return sum(p.numel() for p in recurrent_params) |
|
|
|
def get_param_count(self): |
|
return sum(p.numel() for p in self.parameters()) |
|
|
|
    def add_bias(self, q_bias_value=0.1, k_bias_value=0.1):

        # Fill the q/k projection biases with the given constant values, matching the dtype
        # and device of the existing projection weights.
        for i in range(len(self.model.layers)):

            k_proj = self.model.layers[i].self_attn.k_proj

            q_proj = self.model.layers[i].self_attn.q_proj

            k_proj.bias = nn.Parameter(torch.full((1, k_proj.out_features), k_bias_value, dtype=k_proj.weight.dtype, device=k_proj.weight.device))

            q_proj.bias = nn.Parameter(torch.full((1, q_proj.out_features), q_bias_value, dtype=q_proj.weight.dtype, device=q_proj.weight.device))
|
|
|
@staticmethod |
|
    def add_bias_to_model(model, q_bias_value=0.1, k_bias_value=0.1):

        # Attach constant-valued q/k biases to every layer of `model`, matching the dtype
        # and device of the existing projection weights.
        for i in range(len(model.model.layers)):

            k_proj = model.model.layers[i].self_attn.k_proj

            q_proj = model.model.layers[i].self_attn.q_proj

            k_proj.bias = nn.Parameter(torch.full((1, k_proj.out_features), k_bias_value, dtype=k_proj.weight.dtype, device=k_proj.weight.device))

            q_proj.bias = nn.Parameter(torch.full((1, q_proj.out_features), q_bias_value, dtype=q_proj.weight.dtype, device=q_proj.weight.device))

        return model
|
|
|
@classmethod |
|
def from_llama_model( |
|
cls, |
|
llama_model: LlamaForCausalLM, |
|
prelude_layers: int, |
|
recurrent_layers: int, |
|
coda_layers: int, |
|
mean_recurrence: int = 4, |
|
max_backprop_depth: int = 6, |
|
max_recurrence: int = 8, |
|
) -> "RecLlamaForCausalLM": |
|
""" |
|
Convert a regular LlamaForCausalLM model to a RecLlamaForCausalLM model. |
|
|
|
Args: |
|
llama_model: The source LlamaForCausalLM model |
|
prelude_layers: Number of non-recurrent layers at the start |
|
recurrent_layers: Number of recurrent layers in the middle |
|
coda_layers: Number of non-recurrent layers at the end |
|
            mean_recurrence: Average number of recurrent iterations (default: 4)

            max_backprop_depth: Maximum number of iterations to backpropagate through (default: 6)

            max_recurrence: Maximum total number of recurrent iterations allowed (default: 8)
|
|
|
Returns: |
|
            A RecLlamaForCausalLM model that reuses the source model's weights (the embedding, decoder layers, norm and LM head are shared by reference, not copied)
|
""" |
|
|
|
total_layers = len(llama_model.model.layers) |
|
if prelude_layers + recurrent_layers + coda_layers != total_layers: |
|
raise ValueError( |
|
f"Sum of layers ({prelude_layers + recurrent_layers + coda_layers}) " |
|
f"must equal total number of model layers ({total_layers})" |
|
) |
|
|
|
|
|
config = RecLlamaConfig(**llama_model.config.to_dict()) |
|
config.prelude_layers = prelude_layers |
|
config.recurrent_layers = recurrent_layers |
|
config.coda_layers = coda_layers |
|
config.mean_recurrence = mean_recurrence |
|
config.max_backprop_depth = max_backprop_depth |
|
config.max_recurrence = max_recurrence |
|
|
|
rec_model = cls(config) |
|
rec_model.model.embed_tokens = llama_model.model.embed_tokens |
|
rec_model.model.norm = llama_model.model.norm |
|
rec_model.model.layers = llama_model.model.layers |
|
rec_model.lm_head = llama_model.lm_head |
|
        # Zero-valued q/k biases so the freshly converted model initially matches the source model.
        rec_model = cls.add_bias_to_model(rec_model, q_bias_value=0.0, k_bias_value=0.0)
|
return rec_model |
|
|
|
|
|
def forward( |
|
self, |
|
input_ids: torch.LongTensor = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
labels: Optional[torch.LongTensor] = None, |
|
use_cache: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
cache_position: Optional[torch.LongTensor] = None, |
|
logits_to_keep: Union[int, torch.Tensor] = 0, |
|
num_steps: int = None, |
|
**kwargs: Unpack[KwargsForCausalLM], |
|
) -> Union[Tuple, CausalLMOutputWithPast]: |
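        """Run the prelude layers once, unroll the recurrent core via `iterate_forward`,
        run the coda layers once, then apply the LM head."""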
|
|
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
use_cache = use_cache if use_cache is not None else self.config.use_cache |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
        if inputs_embeds is None:

            inputs_embeds = self.model.embed_tokens(input_ids)
|
|
|
if use_cache and past_key_values is None: |
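            # Build a recurrence-aware cache so each unrolled pass over the recurrent
            # layers gets its own key/value entries.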
|
recurrent_layers = list(range(self.prelude_layers, self.prelude_layers+self.recurrent_layers)) |
|
past_key_values = RecDynamicCache(recurrent_layers) |
|
|
|
if cache_position is None: |
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
|
cache_position = torch.arange( |
|
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device |
|
) |
|
|
|
if position_ids is None: |
|
position_ids = cache_position.unsqueeze(0) |
|
|
|
causal_mask = self.model._update_causal_mask( |
|
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions |
|
) |
|
|
|
position_embeddings = self.model.rotary_emb(inputs_embeds, position_ids) |
|
|
|
|
|
all_hidden_states = () if output_hidden_states else None |
|
all_self_attns = () if output_attentions else None |
|
|
|
        # Stage 1 (prelude): run the first `prelude_layers` blocks once.
        for block_idx, block in enumerate(self.model.layers[:self.prelude_layers]):
|
layer_outputs = block( |
|
inputs_embeds, |
|
attention_mask=causal_mask, |
|
position_ids=position_ids, |
|
past_key_value=past_key_values, |
|
output_attentions=output_attentions, |
|
use_cache=use_cache, |
|
cache_position=cache_position, |
|
position_embeddings=position_embeddings, |
|
) |
|
inputs_embeds = layer_outputs[0] |
|
|
|
if output_attentions: |
|
all_self_attns += (layer_outputs[1],) |
|
|
|
|
|
        # Stage 2 (recurrent core): unroll the shared middle layers; the number of passes is
        # taken from `num_steps` or sampled inside iterate_forward.
        inputs_embeds = self.iterate_forward(
|
inputs_embeds=inputs_embeds, |
|
attention_mask=causal_mask, |
|
position_ids=position_ids, |
|
past_key_value=past_key_values, |
|
output_attentions=output_attentions, |
|
use_cache=use_cache, |
|
cache_position=cache_position, |
|
position_embeddings=position_embeddings, |
|
num_steps=num_steps |
|
) |
|
|
|
|
|
        # Stage 3 (coda): run the remaining `coda_layers` blocks once.
        coda_start = self.prelude_layers + self.recurrent_layers

        for block_idx, block in enumerate(self.model.layers[coda_start : coda_start + self.coda_layers]):
|
layer_outputs = block( |
|
inputs_embeds, |
|
attention_mask=causal_mask, |
|
position_ids=position_ids, |
|
past_key_value=past_key_values, |
|
output_attentions=output_attentions, |
|
use_cache=use_cache, |
|
cache_position=cache_position, |
|
position_embeddings=position_embeddings, |
|
) |
|
inputs_embeds = layer_outputs[0] |
|
|
|
if output_attentions: |
|
all_self_attns += (layer_outputs[1],) |
|
|
|
inputs_embeds = self.model.norm(inputs_embeds) |
|
|
|
if output_hidden_states: |
|
all_hidden_states += (inputs_embeds,) |
|
|
|
outputs = BaseModelOutputWithPast( |
|
last_hidden_state=inputs_embeds, |
|
past_key_values=past_key_values if use_cache else None, |
|
hidden_states=all_hidden_states, |
|
attentions=all_self_attns, |
|
) |
|
|
|
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep |
|
logits = self.lm_head(inputs_embeds[:, slice_indices, :]) |
|
|
|
loss = None |
|
if labels is not None: |
|
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.model.config.vocab_size, **kwargs) |
|
|
|
if not return_dict: |
|
output = (logits,) + outputs[1:] |
|
            return (loss,) + output if loss is not None else output
|
|
|
return CausalLMOutputWithPast( |
|
loss=loss, |
|
logits=logits, |
|
past_key_values=outputs.past_key_values, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
) |
|
|
|
|
|
@torch._dynamo.disable(recursive=False) |
|
def iterate_forward( |
|
self, |
|
inputs_embeds, |
|
attention_mask, |
|
position_ids, |
|
past_key_value, |
|
output_attentions, |
|
use_cache, |
|
cache_position, |
|
position_embeddings, |
|
num_steps=None, |
|
): |
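        """Repeatedly apply the recurrent middle layers; the number of passes is either
        given explicitly or sampled by `randomized_iteration_sampler`."""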
|
        # Choose the iteration schedule: sample one when nothing is specified, unpack an explicit
        # (no_grad, with_grad) pair, fall back to a fixed self.num_steps, or treat a plain int as
        # no-grad-only passes.
        if num_steps is None and self.num_steps is None:
|
num_steps_no_grad, num_steps_with_grad = self.randomized_iteration_sampler() |
|
elif hasattr(num_steps, "__len__") and len(num_steps) > 1: |
|
num_steps_no_grad, num_steps_with_grad = num_steps |
|
elif self.num_steps is not None: |
|
num_steps_no_grad, num_steps_with_grad = self.num_steps, self.num_steps |
|
else: |
|
num_steps_no_grad, num_steps_with_grad = num_steps, torch.tensor(0) |
|
|
|
        # Passes that should not contribute gradients run under no_grad to save memory.
        with torch.no_grad():
|
|
|
|
|
|
|
|
|
for step in range(num_steps_no_grad): |
|
for block_idx, block in enumerate(self.model.layers[self.prelude_layers:self.prelude_layers+self.recurrent_layers]): |
|
|
|
layer_output = block( |
|
inputs_embeds, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
past_key_value=past_key_value, |
|
output_attentions=output_attentions, |
|
use_cache=use_cache, |
|
cache_position=cache_position, |
|
position_embeddings=position_embeddings, |
|
) |
|
inputs_embeds = layer_output[0] |
|
|
|
|
|
for step in range(num_steps_with_grad): |
|
for block_idx, block in enumerate(self.model.layers[self.prelude_layers:self.prelude_layers+self.recurrent_layers]): |
|
layer_output = block( |
|
inputs_embeds, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
past_key_value=past_key_value, |
|
output_attentions=output_attentions, |
|
use_cache=use_cache, |
|
cache_position=cache_position, |
|
position_embeddings=position_embeddings, |
|
) |
|
inputs_embeds = layer_output[0] |
|
|
|
return inputs_embeds |
|
|
|
|
|
@torch._dynamo.disable(recursive=False) |
|
def randomized_iteration_sampler(self) -> tuple[torch.Tensor, torch.Tensor]: |
|
"""Outputs are long tensors so that they can be passed through compiled functions""" |
|
t = max(self.config.mean_recurrence, 0) |
|
if self.training: |
|
sigma = 0.5 |
|
mu = math.log(t) - (sigma**2 / 2) |
|
rate = torch.zeros((1,), dtype=torch.float).log_normal_(mean=mu, std=sigma) |
|
n = torch.poisson(rate) + 1 |
|
n = torch.clamp(n, min=0, max=self.config.max_recurrence) |
|
k = torch.clamp(n, max=self.config.max_backprop_depth) |
|
else: |
|
n = torch.tensor(self.config.mean_recurrence, dtype=torch.long) |
|
k = torch.tensor(0, dtype=torch.long) |
|
|
|
return n.to(dtype=torch.long), k.to(dtype=torch.long) |
|
|
|
@torch.no_grad() |
|
def generate(self, *args, **kwargs): |
|
return super().generate(*args, **kwargs) |
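

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the modeling API): convert a plain
    # Llama checkpoint into a RecLlama model and run a single forward pass. The checkpoint
    # name below is only an assumed example; any LlamaForCausalLM works as long as
    # prelude + recurrent + coda equals its number of decoder layers.
    from transformers import AutoTokenizer

    base_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # assumed example checkpoint
    base = LlamaForCausalLM.from_pretrained(base_name)
    tokenizer = AutoTokenizer.from_pretrained(base_name)

    # Split the decoder into 2 prelude layers, a recurrent core, and 2 coda layers.
    n_layers = base.config.num_hidden_layers
    rec_model = RecLlamaForCausalLM.from_llama_model(
        base,
        prelude_layers=2,
        recurrent_layers=n_layers - 4,
        coda_layers=2,
        mean_recurrence=4,
        max_backprop_depth=6,
        max_recurrence=8,
    )

    inputs = tokenizer("Recurrent depth lets a small model", return_tensors="pt")
    with torch.no_grad():
        # num_steps=(no_grad_passes, with_grad_passes) over the recurrent core.
        out = rec_model(**inputs, num_steps=(4, 0))
    print(out.logits.shape)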