# RecLlama-code / modeling_recllama.py
from transformers.models.llama.modeling_llama import LlamaModel, LlamaConfig, LlamaForCausalLM, KwargsForCausalLM
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union, Any, Dict
from transformers.cache_utils import Cache, DynamicCache
from transformers.utils import ModelOutput
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.processing_utils import Unpack
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers import PretrainedConfig
import math
class RecLlamaConfig(PretrainedConfig):
    """Llama configuration extended with a prelude / recurrent / coda layer split and the
    recurrence-depth hyperparameters `mean_recurrence`, `max_backprop_depth` and `max_recurrence`."""

    model_type = "rec_llama"
def __init__(
self,
vocab_size=32000,
hidden_size=4096,
intermediate_size=11008,
num_hidden_layers=32,
num_attention_heads=32,
num_key_value_heads=None,
hidden_act="silu",
max_position_embeddings=2048,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
mlp_bias=False,
head_dim=None,
prelude_layers:int = 2,
recurrent_layers:int = 2,
coda_layers:int = 2,
mean_recurrence:int = 12,
max_backprop_depth:int = 8,
max_recurrence:int = 18,
**kwargs
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mlp_bias = mlp_bias
self.head_dim = head_dim if head_dim is not None else self.hidden_size // self.num_attention_heads
        # Backward compatibility: if `rope_scaling` has a legacy 'type' field, copy it to 'rope_type'.
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
self.prelude_layers = prelude_layers
self.recurrent_layers = recurrent_layers
self.coda_layers = coda_layers
self.mean_recurrence = mean_recurrence
self.max_backprop_depth = max_backprop_depth
self.max_recurrence = max_recurrence
        self.auto_map = {
            "AutoModelForCausalLM": "Arthur-LAGACHERIE/RecLlama-code--modeling_recllama.RecLlamaForCausalLM",
            "AutoConfig": "Arthur-LAGACHERIE/RecLlama-code--modeling_recllama.RecLlamaConfig",
        }
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
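
# Illustrative sketch, not used by the modeling code: the values below are assumptions chosen only
# to show how `prelude_layers + recurrent_layers + coda_layers` partitions `num_hidden_layers`, and
# how `mean_recurrence` sets the average unrolling depth of the middle block.
def _example_recllama_config() -> RecLlamaConfig:
    return RecLlamaConfig(
        hidden_size=256,
        intermediate_size=512,
        num_hidden_layers=6,   # equals 2 + 2 + 2 below (the split is validated in from_llama_model)
        num_attention_heads=4,
        prelude_layers=2,      # fixed layers applied once at the start
        recurrent_layers=2,    # layers looped ~mean_recurrence times on average during training
        coda_layers=2,         # fixed layers applied once at the end
        mean_recurrence=12,
    )
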
class RecDynamicCache(DynamicCache):
    """Dynamic KV cache that keeps one entry per non-recurrent layer and a separate entry
    for every pass through a recurrent layer."""

    def __init__(self, rec_layers: List[int]) -> None:
        super().__init__()
        self._seen_tokens = 0  # used in `generate` to keep a tally of how many tokens the cache has seen
        self.rec_layers = rec_layers
        # Caches are keyed by name ("layer-{idx}" or "rec-{idx}-{step}") instead of the
        # list-per-layer layout used by the parent class.
        self.key_cache: Dict[str, torch.Tensor] = {}
        self.value_cache: Dict[str, torch.Tensor] = {}
        self.rec_counters = {layer: 0 for layer in rec_layers}
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if layer_idx not in self.rec_layers:
# Not a recurrent layer
layer_name = f"layer-{layer_idx}"
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
# Update the cache
if key_states is not None:
if layer_name not in self.key_cache:
self.key_cache[layer_name] = key_states
self.value_cache[layer_name] = value_states
else:
self.key_cache[layer_name] = torch.cat([self.key_cache[layer_name], key_states], dim=-2)
self.value_cache[layer_name] = torch.cat([self.value_cache[layer_name], value_states], dim=-2)
else:
# Recurrent layer
layer_name = f"rec-{layer_idx}-{self.rec_counters[layer_idx]}"
self.rec_counters[layer_idx] += 1
# Update the cache for recurrent layers
if layer_name not in self.key_cache:
self.key_cache[layer_name] = key_states
self.value_cache[layer_name] = value_states
else:
self.key_cache[layer_name] = torch.cat([self.key_cache[layer_name], key_states], dim=-2)
self.value_cache[layer_name] = torch.cat([self.value_cache[layer_name], value_states], dim=-2)
        return self.key_cache[layer_name], self.value_cache[layer_name]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        # The parent implementation indexes list-based caches by integer layer index, which does
        # not work with the string-keyed dictionaries above; report the running token tally instead.
        return self._seen_tokens
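
# Illustrative sketch of the cache layout above (shapes are assumptions chosen for the example):
# non-recurrent layers share one entry per layer ("layer-{idx}"), while every pass through a
# recurrent layer gets its own entry ("rec-{idx}-{step}").
def _example_rec_cache_keys() -> List[str]:
    cache = RecDynamicCache(rec_layers=[2, 3])
    k = torch.zeros(1, 4, 5, 8)  # (batch, num_kv_heads, seq_len, head_dim)
    v = torch.zeros(1, 4, 5, 8)
    cache.update(k, v, layer_idx=0)      # stored under "layer-0"
    cache.update(k, v, layer_idx=2)      # first recurrence  -> "rec-2-0"
    cache.update(k, v, layer_idx=2)      # second recurrence -> "rec-2-1"
    return list(cache.key_cache.keys())  # ["layer-0", "rec-2-0", "rec-2-1"]
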
class RecLlamaForCausalLM(LlamaForCausalLM):
    """LlamaForCausalLM whose middle block of layers is applied recurrently:
    prelude layers -> (recurrent layers repeated a sampled number of times) -> coda layers."""

    config_class = RecLlamaConfig
def __init__(self, config: RecLlamaConfig, num_steps=None):
super().__init__(config)
self.prelude_layers = config.prelude_layers
self.recurrent_layers = config.recurrent_layers
self.coda_layers = config.coda_layers
self.num_steps = num_steps
        # Add learnable biases to the q/k projections of every layer (LlamaAttention has none by default).
        for i in range(len(self.model.layers)):
            attn = self.model.layers[i].self_attn
            attn.k_proj.bias = nn.Parameter(torch.randn(1, attn.k_proj.out_features))
            attn.q_proj.bias = nn.Parameter(torch.randn(1, attn.q_proj.out_features))
def get_recurrent_params(self):
recurrent_params = []
# Get indices of recurrent layers
recurrent_start = self.prelude_layers
recurrent_end = self.prelude_layers + self.recurrent_layers
# Extract parameters from recurrent layers
for layer_idx in range(recurrent_start, recurrent_end):
layer = self.model.layers[layer_idx]
for param_name, param in layer.named_parameters():
recurrent_params.append(param)
return sum(p.numel() for p in recurrent_params)
def get_param_count(self):
return sum(p.numel() for p in self.parameters())
    def add_bias(self, q_bias_value=0.1, k_bias_value=0.1):
        # Note: the value arguments are currently unused; biases are (re)initialized with random values.
        for i in range(len(self.model.layers)):
            attn = self.model.layers[i].self_attn
            attn.k_proj.bias = nn.Parameter(torch.randn(1, attn.k_proj.out_features))
            attn.q_proj.bias = nn.Parameter(torch.randn(1, attn.q_proj.out_features))
    @staticmethod
    def add_bias_to_model(model, q_bias_value=0.1, k_bias_value=0.1):
        # Note: the value arguments are currently unused; biases are initialized to zero here.
        for i in range(len(model.model.layers)):
            attn = model.model.layers[i].self_attn
            attn.k_proj.bias = nn.Parameter(torch.zeros(1, attn.k_proj.out_features))
            attn.q_proj.bias = nn.Parameter(torch.zeros(1, attn.q_proj.out_features))
        return model
@classmethod
def from_llama_model(
cls,
llama_model: LlamaForCausalLM,
prelude_layers: int,
recurrent_layers: int,
coda_layers: int,
mean_recurrence: int = 4,
max_backprop_depth: int = 6,
max_recurrence: int = 8,
) -> "RecLlamaForCausalLM":
"""
Convert a regular LlamaForCausalLM model to a RecLlamaForCausalLM model.
Args:
llama_model: The source LlamaForCausalLM model
prelude_layers: Number of non-recurrent layers at the start
recurrent_layers: Number of recurrent layers in the middle
coda_layers: Number of non-recurrent layers at the end
mean_recurrence: Average number of recurrent iterations (default: 1)
max_backprop_depth: Maximum number of iterations to backpropagate through (default: 1)
Returns:
A RecLlamaForCausalLM model with weights copied from the source model
"""
# Validate layer counts
total_layers = len(llama_model.model.layers)
if prelude_layers + recurrent_layers + coda_layers != total_layers:
raise ValueError(
f"Sum of layers ({prelude_layers + recurrent_layers + coda_layers}) "
f"must equal total number of model layers ({total_layers})"
)
# Create new config based on original model's config
config = RecLlamaConfig(**llama_model.config.to_dict())
config.prelude_layers = prelude_layers
config.recurrent_layers = recurrent_layers
config.coda_layers = coda_layers
config.mean_recurrence = mean_recurrence
config.max_backprop_depth = max_backprop_depth
config.max_recurrence = max_recurrence
rec_model = cls(config)
rec_model.model.embed_tokens = llama_model.model.embed_tokens
rec_model.model.norm = llama_model.model.norm
rec_model.model.layers = llama_model.model.layers
rec_model.lm_head = llama_model.lm_head
rec_model = RecLlamaForCausalLM.add_bias_to_model(rec_model)
return rec_model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
        num_steps: Optional[Union[int, Tuple[int, int]]] = None,
**kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
if use_cache and past_key_values is None:
recurrent_layers = list(range(self.prelude_layers, self.prelude_layers+self.recurrent_layers))
past_key_values = RecDynamicCache(recurrent_layers)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self.model._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
position_embeddings = self.model.rotary_emb(inputs_embeds, position_ids)
# run non-recurrent blocks (prelude)
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
for block_idx, block in enumerate(self.model.layers[:self.prelude_layers]):
layer_outputs = block(
inputs_embeds,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
inputs_embeds = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
# recurrent block
inputs_embeds = self.iterate_forward(
inputs_embeds=inputs_embeds,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
num_steps=num_steps
)
# coda blocks
for block_idx, block in enumerate(self.model.layers[self.prelude_layers+self.recurrent_layers : self.prelude_layers+self.recurrent_layers+self.coda_layers]):
layer_outputs = block(
inputs_embeds,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
inputs_embeds = layer_outputs[0]
if output_attentions:
all_self_attns += (layer_outputs[1],)
inputs_embeds = self.model.norm(inputs_embeds)
if output_hidden_states:
all_hidden_states += (inputs_embeds,)
outputs = BaseModelOutputWithPast(
last_hidden_state=inputs_embeds,
past_key_values=past_key_values if use_cache else None,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(inputs_embeds[:, slice_indices, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.model.config.vocab_size, **kwargs)
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
@torch._dynamo.disable(recursive=False) # type: ignore
def iterate_forward(
self,
inputs_embeds,
attention_mask,
position_ids,
past_key_value,
output_attentions,
use_cache,
cache_position,
position_embeddings,
num_steps=None,
):
        # Resolve the number of recurrent iterations: either sample (no-grad, with-grad) depths,
        # unpack an explicit (no_grad, with_grad) pair, fall back to the fixed `self.num_steps`,
        # or run `num_steps` iterations without backpropagating through the recurrence.
        if num_steps is None and self.num_steps is None:
            num_steps_no_grad, num_steps_with_grad = self.randomized_iteration_sampler()  # type: ignore
        elif hasattr(num_steps, "__len__") and len(num_steps) > 1:
            num_steps_no_grad, num_steps_with_grad = num_steps
        elif self.num_steps is not None:
            num_steps_no_grad, num_steps_with_grad = self.num_steps, self.num_steps
        else:
            num_steps_no_grad, num_steps_with_grad = num_steps, torch.tensor(0)
with torch.no_grad():
            # Ultra annoying in DDP, see
            # https://discuss.pytorch.org/t/does-distributeddataparallel-work-with-torch-no-grad-and-find-unused-parameters-false/122594
            # For now we run with find_unused_parameters=True even though the graph structure is
            # (technically) clear and all parameters are always used.
for step in range(num_steps_no_grad):
for block_idx, block in enumerate(self.model.layers[self.prelude_layers:self.prelude_layers+self.recurrent_layers]):
layer_output = block(
inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
inputs_embeds = layer_output[0]
for step in range(num_steps_with_grad):
for block_idx, block in enumerate(self.model.layers[self.prelude_layers:self.prelude_layers+self.recurrent_layers]):
layer_output = block(
inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
)
inputs_embeds = layer_output[0]
return inputs_embeds
@torch._dynamo.disable(recursive=False) # type: ignore
def randomized_iteration_sampler(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Outputs are long tensors so that they can be passed through compiled functions"""
        t = max(self.config.mean_recurrence, 1)  # guard against log(0) below
        if self.training:
            # Sample the total depth n from a log-normal Poisson centred around mean_recurrence,
            # and backpropagate through at most max_backprop_depth of those iterations.
            sigma = 0.5
            mu = math.log(t) - (sigma**2 / 2)
            rate = torch.zeros((1,), dtype=torch.float).log_normal_(mean=mu, std=sigma)
            n = torch.poisson(rate) + 1
            n = torch.clamp(n, min=0, max=self.config.max_recurrence)
            k = torch.clamp(n, max=self.config.max_backprop_depth)
else:
n = torch.tensor(self.config.mean_recurrence, dtype=torch.long)
k = torch.tensor(0, dtype=torch.long)
return n.to(dtype=torch.long), k.to(dtype=torch.long)
@torch.no_grad()
def generate(self, *args, **kwargs):
return super().generate(*args, **kwargs)
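
if __name__ == "__main__":
    # Minimal end-to-end sketch, not part of the modeling code. The checkpoint name and the
    # 4/4/8 layer split are assumptions for illustration; any LlamaForCausalLM whose layer
    # count equals prelude_layers + recurrent_layers + coda_layers works the same way.
    from transformers import AutoTokenizer

    checkpoint = "meta-llama/Llama-3.2-1B"  # assumed here to have 16 decoder layers
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    base_model = LlamaForCausalLM.from_pretrained(checkpoint)

    rec_model = RecLlamaForCausalLM.from_llama_model(
        base_model,
        prelude_layers=4,
        recurrent_layers=4,
        coda_layers=8,
        mean_recurrence=8,
    )
    rec_model.eval()  # at eval time, randomized_iteration_sampler runs `mean_recurrence` no-grad iterations

    inputs = tokenizer("The capital of France is", return_tensors="pt")
    output_ids = rec_model.generate(**inputs, max_new_tokens=16)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))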