SepCache - Native Sparse Attention Cache

This repository is a fork of transformers-community/sep_cache, kept for backup purposes.

Please prioritize the official HuggingFace repository transformers-community/sep_cache; the two are identical.

Table of Contents

1. Abstract
2. Usage
3. Adaptation for Other Models
4. Other Advanced Usage

1. Abstract

SepCache is a simple yet effective, native sparse attention Cache class proposed in the SepLLM paper - ICML 2025, and it closely aligns with the natural semantic distribution of language. In the training phase, SepLLM condenses each segment's information into the KV of the separator token that divides the segments. In the inference phase, the corresponding SepCache only needs to store the KVs of initial tokens, separator tokens, and recent tokens for generation.

Notably, SepCache also delivers strong performance across many tasks in training-free scenarios. Moreover, SepLLM (or simply SepCache) is the most suitable baseline method for sparse attention mechanisms and KV compression/management, as it is the natively sparse attention mechanism that best aligns with the natural semantic distribution of language.

See https://github.com/HKUDS/SepLLM for more details and advanced usage.


2. Usage

2.1 Sample Base Model

We recommend using models from the Llama 3 series. Our example model is based on meta-llama/Meta-Llama-3-8B-Instruct, for which we have already prepared a targeted monkey patch.

For other models, using SepCache requires minor modifications to the corresponding modeling_xxx.py file or writing a custom monkey patch. These changes are very simple -- you only need to pass arguments like input_ids to the update function of SepCache when calling it.

A detailed guide on how to modify your modeling_xxx.py file or monkey patch to adapt SepCache to any model is provided later in this document (see 3. Adaptation for Other Models).

2.2 Quick Start

2.2.1 Environment Setup

You need to install transformers>=4.53.0,<4.54.0, and we recommend using lm_eval>=0.4.9 for running evaluations. We suggest managing your Python environment with conda for better dependency control.

conda create -n sepcache python=3.10
conda activate sepcache
pip install transformers==4.53
pip install lm_eval==0.4.9
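
Optionally, you can verify the installed transformers version from Python before proceeding:

import transformers
print(transformers.__version__)  # should print a 4.53.x version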

2.2.2 A Simple Example

You can use SepCache by specifying custom_generate="transformers-community/sep_cache" or custom_generate="Gausson/sep_cache" when calling the generate function. In our demo, we have already prepared sample monkey patching for the Llama 3 series models and provided some common parameters for initializing SepCache.

# requires `transformers>=4.53.0,<4.54.0`
from transformers import AutoModelForCausalLM, AutoTokenizer

# Preparing model, tokenizer, and model inputs
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto")


messages = [{"role": "user", "content": "Tell me a story about a cat."}]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)


# Using SepCache for generation
gen_out = model.generate(
    # usual `generate` arguments
    **model_inputs,
    do_sample=False,
    max_new_tokens=100,
    return_dict_in_generate=True,
    monkey_patch_verbose = True, # To see which functions are actually being monkey patched for `SepCache`.

    # Using SepCache
    custom_generate="transformers-community/sep_cache", ## Alternatively, you can use `Gausson/sep_cache`
    trust_remote_code=True,

    # SepCache arguments
    init_cache_size = 4,
    sep_cache_size = 128,
    local_size = 256, 
    cache_size = 512,    
    USE_MAX_SEP_CACHE = True,
    model_type = 'llama'
)

print(tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True))
assert "sepcache" in str(type(gen_out.past_key_values)).lower()

It is worth noting that you must specify the separator_token_ids: List[int] and PADDING_ID: int parameters when initializing SepCache. We did not do so in the example above because, for convenience, we specified model_type = "llama", in which case separator_token_ids and PADDING_ID are filled in automatically.

However, when you use the tokenizer of a model outside the Llama 3 series, you need to specify separator_token_ids and PADDING_ID explicitly, based on the tokenizer you are using. The example below uses the values obtained from a Llama 3 series tokenizer.

# Using SepCache for generation
gen_out = model.generate(
    # usual `generate` arguments
    **model_inputs,
    do_sample=False,
    max_new_tokens=100,
    return_dict_in_generate=True,
    monkey_patch_verbose = True, # To see which functions are actually being monkey patched for `SepCache`.

    # Using SepCache
    custom_generate="transformers-community/sep_cache", ## Alternatively, you can use `Gausson/sep_cache`
    trust_remote_code=True,

    # SepCache arguments
    init_cache_size = 4,
    sep_cache_size = 128,
    local_size = 256, 
    cache_size = 512,    
    USE_MAX_SEP_CACHE = True,    
    separator_token_ids = [128000, 13, 11, 30, 0, 26, 25, 198, 220, 662, 1174, 949, 758, 2652, 551, 720, 256,262],
    PADDING_ID = 128009
)
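
If you are working with a different tokenizer, one way to obtain these values is to encode the separator strings you care about and collect the resulting token ids. The snippet below is a minimal sketch (not part of the SepCache repository); the model name and the exact set of separator strings are assumptions you should adapt to your own tokenizer.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your-model-name")  # hypothetical model name

# The separator strings below are an assumption -- adjust them to how your tokenizer splits text.
separator_strings = [".", ",", "?", "!", ";", ":", "\n", " .", " ,", " ?", " !", " ;", " :"]

separator_token_ids = sorted({
    tok_id
    for s in separator_strings
    for tok_id in tokenizer.encode(s, add_special_tokens=False)
})

# A common choice for `PADDING_ID` is the tokenizer's pad (or end-of-text) token id.
PADDING_ID = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

print(separator_token_ids, PADDING_ID)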

2.2.3 Frequently-Used Parameters

Below, we provide explanations and examples for the most commonly used parameters when initializing SepCache. These parameters can be passed through the generate function.

`SepCache` stores the Key and Value states as lists of tensors: a Key list and a Value list, each holding one tensor per layer. The expected shape for each tensor is
`[batch_size, num_heads, seq_len, head_dim]`.
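
For concreteness, a toy illustration of this layout (the shapes are arbitrary; in practice the cache is built for you by generate):

import torch

batch_size, num_heads, seq_len, head_dim, layer_num = 2, 8, 16, 64, 32
# One Key tensor and one Value tensor per layer, each shaped [batch_size, num_heads, seq_len, head_dim].
key_cache = [torch.randn(batch_size, num_heads, seq_len, head_dim) for _ in range(layer_num)]
value_cache = [torch.randn(batch_size, num_heads, seq_len, head_dim) for _ in range(layer_num)]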

Frequently-Used Parameters:

    `init_cache_size: Union[int, List]`:
        The maximum number of KVs to be stored for initial tokens.
        In the paper, the hyperparameter `a` is an abbreviated alias for `init_cache_size`.                
            
    `sep_cache_size: Union[int, List]`:
        The maximum number of KVs to be stored for separator tokens.
        In the paper, the hyperparameter `s` is an abbreviated alias for `sep_cache_size`.

    `local_size: Union[int, List]`: 
        The maximum number of KVs to be stored for local tokens (i.e., sliding window).
        In the paper, the hyperparameter `w` is an abbreviated alias for `local_size`.

    `cache_size: Union[int, List]`:    
        The maximum number of KVs to be stored for all the tokens, i.e., the size for the whole KV cache.  
        In the paper, the hyperparameter `c` is an abbreviated alias for `cache_size`.

    Concerning these four parameters above:
        When a list is passed (its length must be `layer_num`), it represents different values for each layer. 
        When an integer is passed, it means the setting is the same for all layers.
    
    
    `USE_MAX_SEP_CACHE: bool`: 
        If True, it means we only keep at most `sep_cache_size` separators' KVs.  
        If the number exceeds this limit, older separators' KVs will be discarded, keeping only the most recent `sep_cache_size` KVs. 
      
    `separator_token_ids: List[int]`:
        The token ids of the separator tokens for the current model's tokenizer.            
        We have some examples, such as the Llama-3 series models, where setting `model_type='llama'` allows you 
            to skip setting `separator_token_ids` and `PADDING_ID` (SepCache will auto-fill them).

    `PADDING_ID: int`:
        The token id of the padding token. You can just set `PADDING_ID` to the id of "<|endoftext|>" token of the tokenizer for the pretrained model.  

Important Note:

  • When cache_size and local_size are set to infinity (i.e., sufficiently large positive integers) and USE_MAX_SEP_CACHE is False, SepCache degenerates into a regular Cache.
  • You must always ensure that init_cache_size + sep_cache_size + local_size + left_padding_offset < cache_size, where left_padding_offset denotes the number of padding tokens in the record with the most left padding within a runtime batch. left_padding_offset can only be determined at runtime.
  • To guarantee that this inequality always holds at runtime, leave a sufficient margin when choosing init_cache_size + sep_cache_size + local_size < cache_size (i.e., a + s + w < c in the SepLLM paper - ICML 2025) so that there is room for left_padding_offset. For example, the demo settings satisfy 4 + 128 + 256 = 388 < 512, leaving a margin of 124 (a quick check is sketched below).
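
The following is a minimal helper (not part of SepCache; the expected_max_left_padding value is an assumption you should set to the largest left padding you expect in a batch) to sanity-check integer settings before generation:

def check_sepcache_sizes(init_cache_size, sep_cache_size, local_size, cache_size, expected_max_left_padding=0):
    # Check a + s + w (+ an allowance for left padding) < c.
    used = init_cache_size + sep_cache_size + local_size + expected_max_left_padding
    assert used < cache_size, f"a + s + w + padding = {used} must be < cache_size = {cache_size}"
    return cache_size - used  # remaining margin

print(check_sepcache_sizes(4, 128, 256, 512))  # demo settings: prints 124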

More Important Note: In practice, there is no need to do positional encoding (PE) shifting (as in StreamingLLM) as long as the actual sequence length does not exceed the pretrained maximum PE length, which is the case for most downstream tasks. So, for most basic usage, just set APPLY_PE_SHIFT=False (the default) and APPLY_PES_INSIDE=False at initialization.

2.2.4 Update Function

After initialization, another key point is that when calling the update function of SepCache to update the keys/values and the past token ids (which SepCache requires), you must also provide the current input_ids.

key_states, value_states = past_key_values.update(
    key_states=key_states,
    value_states=value_states,
    input_ids=input_ids,          ## required by SepCache
    layer_idx=layer_idx,
    PREFILLING_FLAG=q_len > 1,    ## `q_len` is the sequence length of the current `query_states`
)

2.2.5 Monkey Patch Demo

To adapt a model to the update function of SepCache described in 2.2.4 Update Function, the current input_ids must be passed as a parameter to update. It is worth noting that during the prefilling stage, the shape of the input_ids tensor is [batch_size, seq_len], while during the decoding stage of auto-regressive models, the shape of the input_ids tensor should be [batch_size, 1].

In our custom_generate/generate.py file, we provide the monkey_patching function, which works by replacing the forward function in all the related instances of the XXXAttention class (for example, in the Llama 3 series model, it would be LlamaAttention) with our customized forward function (specified by the model_atten_forward parameter of the monkey_patching function).

def monkey_patching(model_obj, 
                    model_atten_forward,  ## The `forward` function used for patching.
                    possible_inner_model_names: List[str] = ["model", "transformer", "gpt_neox"],  # In the `XXXForCausalLM` class, the possible attribute names of the inner model, e.g., "model", "transformer", "gpt_neox", etc.
                    possible_layers_names: List[str] = ["layers", "h"],  # In the `XXXModel` class, the possible attribute names of the decoder layers, e.g., "layers", "h", etc.
                    atten_attr_name_pattern_list: List[str] = ["attention", "self_attn"],  # In the `XXXDecoderLayer` class, the possible attribute name patterns of the self-attention module, e.g., "attention", "self_attn", etc.
                    atten_attr_name_pattern_exclude: List[str] = ["norm", "layer"],  # In the `XXXDecoderLayer` class, the attribute name patterns to exclude. Attributes such as "post_attention_norm" would otherwise match, but we only want to replace the `forward` function of `XXXAttention`, so patterns like "norm" must be excluded to locate the correct `forward` function.
                    verbose = True):
    
    """
    This `monkey_patching` function is to
        - find the `forward` function of the `XXXAttention` class.
        - replace all the related `forward` functions of the instances of `XXXAttention` class with `model_atten_forward`.
    """
    
    ## Bypass the model-kwargs validation so that "sepllm_kwargs" passes the check.
    transformers.generation.GenerationMixin._validate_model_kwargs = _validate_model_kwargs    

    ## Get inner model obj
    inner_model_type = PreTrainedModel
    inner_model = find_inner_attribute(model_obj, possible_inner_model_names, inner_model_type)
    
    ## Get the decoder layers (`nn.ModuleList`) obj
    layers_type = nn.ModuleList
    model_layers = find_inner_attribute(inner_model, possible_layers_names, layers_type)
    
    ## Replace all the related `forward` functions of XXXAttention class's instances.
    for i, decoder_layer in enumerate(model_layers):
        self_attn_module = find_attribute_name(decoder_layer, atten_attr_name_pattern_list, atten_attr_name_pattern_exclude, nn.Module)
        result = monkey_patch_by_class_path(self_attn_module, model_atten_forward)
        if verbose:
            decoder_class_name = get_importable_class_path(decoder_layer)
            print(f"For Layer {i}'s `{decoder_class_name}`: {result}")

    return model_layers

The monkey_patching function primarily does three things:

  • Precisely locate the forward function of all instances of the XXXAttention class.
  • Replace the forward function with the model_atten_forward function you provide.
  • Return the decoder-layer container found during the process, typically of type nn.ModuleList. This return value (model_layers) is only used later to determine the number of layers in the current model (via len(model_layers)).

In addition, the monkey_patching function replaces transformers.generation.GenerationMixin._validate_model_kwargs with our _validate_model_kwargs to bypass some argument checks, since we pass an additional sepllm_kwargs parameter that wraps input_ids for eventual delivery to the update function of SepCache.

Please ensure that the monkey_patching function accurately locates and replaces the forward function of the XXXAttention class. The current monkey_patching is designed for the Llama 3 series models; for other models, you need to modify it appropriately so that it targets and replaces the correct forward function! You can monitor the monkey patching process by setting verbose=True in the monkey_patching function (or monkey_patch_verbose=True in the generate function).
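
For example, here is a minimal sketch of calling monkey_patching with attribute-name patterns adjusted for a GPT-NeoX-style architecture. The model object and my_neox_atten_forward are hypothetical placeholders you would supply yourself:

# Hypothetical example: patching a GPT-NeoX-style model.
model_layers = monkey_patching(
    model_obj=model,
    model_atten_forward=my_neox_atten_forward,         # your customized attention `forward` (hypothetical)
    possible_inner_model_names=["gpt_neox", "model"],   # GPT-NeoX keeps the inner model in `gpt_neox`
    possible_layers_names=["layers", "h"],
    atten_attr_name_pattern_list=["attention", "self_attn"],
    atten_attr_name_pattern_exclude=["norm", "layer"],
    verbose=True,
)
layer_num = len(model_layers)  # used later when initializing SepCache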

def truncate_input_ids_4_autoregression(input_ids, key_states):
    # During decoding, `input_ids` may still cover the full sequence while `key_states`
    # only covers the newly generated token(s); keep only the trailing token ids that
    # correspond to `key_states`.
    if input_ids.shape[-1] != key_states.shape[-2]:
        assert input_ids.shape[-1] >= key_states.shape[-2]
        truncated_input_ids = input_ids[..., -key_states.shape[-2]:]
        return truncated_input_ids
    else:
        return input_ids

The truncate_input_ids_4_autoregression function in the custom_generate/generate.py file is used to shape the input_ids tensor to [batch_size, 1] during decoding.
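
As a quick illustration (the shapes are arbitrary and the function above is assumed to be in scope):

import torch

# Decoding step: `input_ids` still holds the whole prompt, but only one new KV pair was produced.
input_ids = torch.randint(0, 1000, (2, 10))   # [batch_size, seq_len]
key_states = torch.randn(2, 8, 1, 64)         # [batch_size, num_heads, 1, head_dim]

print(truncate_input_ids_4_autoregression(input_ids, key_states).shape)  # torch.Size([2, 1])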

2.2.6 Downstream Task Evaluation

We recommend using lm_eval==0.4.9 for downstream task evaluation. You can pass model-related parameters via --model_args and generation-related parameters (including those required for initializing SepCache) via --gen_kwargs. Notably, you typically need to pass a list to separator_token_ids using a string format like "id1;id2;id3" (as shown in the example below).

lm_eval --model hf \
    --model_args pretrained=meta-llama/Meta-Llama-3-8B-Instruct,attn_implementation=flash_attention_2 \
    --tasks gsm8k_cot \
    --gen_kwargs custom_generate=transformers-community/sep_cache,trust_remote_code=True,monkey_patch_verbose=True,init_cache_size=4,sep_cache_size=128,local_size=256,cache_size=512,separator_token_ids="128000;13;11;30;0;26;25;198;220;662;1174;949;758;2652;551;720;256;262",PADDING_ID=128009 \
    --device cuda:0 \
    --batch_size 80 2>&1 | tee log.txt

Note: SepCache is typically used in combination with Flash Attention to maximize generation efficiency.


2.2.7 The Detailed Signature of generate Function

Here is the detailed signature of our customized generate function for SepCache in the custom_generate/generate.py file:

def generate(model,                          
            ## For SepCache                              
            init_cache_size: Union[int, List] = 4,        
            sep_cache_size: Union[int, List] = 128,
            local_size: Union[int, List]=256, 
            cache_size: Union[int, List]=512,    
            SEP_ACCUMULATION: bool = True,
            USE_MAX_SEP_CACHE: bool = False,
            SEP_PADDING_IN_BATCH: bool = False,
            separator_token_ids: List[int] = None, ## required for initialization if `model_type` is not provided.
            PADDING_ID: int = None, ## required for initialization if `model_type` is not provided.

            ## For inheritance & initialization states
            past_tok_ids: List[torch.Tensor] = None,  ## It saves all the token ids corresponding to the saved KVs for all layers in SepCache.                
            key_cache: List[torch.Tensor] = None,          
            value_cache: List[torch.Tensor] = None,

            ## For debugging
            PRINT_KV_RATIO_INSIDE: bool = False,
            print_KV_inside_per_steps: int = 1000,   
            _seen_tokens: int = 0, 
            _kept_kv_ratio: List[Tuple[int]] = None,
            
            ### For positional encoding shifting
            APPLY_PE_SHIFT: bool = False,
            APPLY_PES_INSIDE: bool = False,
            _shifted_position_ids:  List[torch.Tensor] = None,
            _rope_unsqueeze_dim: int = 1, ## The unsqueeze_dim when applying RoPE.
            _rope_seq_dim: int=1, ## The seq_len dimension for the `cos` or `sin` tensors.
            pe_scaling_factor:float = 1.0,
            pe_dim:int=128, ## The number of dims for positional encoding. Typically, just set the `head_dim` to this.
            max_position_embeddings: int = 8192, 
            base: int=10000,  ## The base for RoPE.               
            
            ## For basic transformer architecture
            k_seq_dim: int=2, ## The dimension for seq_len in key tensors
            v_seq_dim: int=2, ## The dimension for seq_len in value tensors
            layer_num: int = None, ## required for initialization

            model_type: str = 'llama',  ## The model type for running the example. choose from ['llama', 'pythia','falcon'].
            device = None,   

            ## For verbosity of monkey patching
            monkey_patch_verbose: bool = False,

            **kwargs
             ):
             ...

3. Adaptation for Other Models

Adapting SepCache to other models is simple; there are two approaches:

3.1 Method 1 - Monkey Patching

  • Modify the monkey_patching function to correctly locate and target the forward function of your model's XXXAttention class (e.g., LlamaAttention for Llama 3).
  • Write your custom model_atten_forward function and use monkey_patching to replace the forward function of all XXXAttention class instances. The key modification is passing input_ids to SepCache's update function.

3.2 Method 2 - Direct Code Modification (Recommended for Simplicity)

Simply edit your modeling_xxx.py file to implement:

  • Initialize past_key_values as a SepCache instance at the appropriate location (e.g., in the forward function of the XXXForCausalLM or XXXModel class).
  • Modify the forward function of the XXXAttention class to pass input_ids to SepCache's update function (see the sketch after this list).
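
Below is a minimal sketch of these two edits. It is not the actual modeling code: the wrapper function names are hypothetical, the SepCache import path is not shown, and the constructor/update arguments simply follow the signatures documented earlier in this README, so adapt them to your model.

from typing import List

def init_sepcache_in_model_forward(past_key_values, config, SepCache, separator_token_ids: List[int], PADDING_ID: int):
    # (1) Inside `XXXModel.forward` / `XXXForCausalLM.forward`: create the cache once if the caller did not pass one in.
    if past_key_values is None:
        past_key_values = SepCache(
            init_cache_size=4,
            sep_cache_size=128,
            local_size=256,
            cache_size=512,
            USE_MAX_SEP_CACHE=True,
            separator_token_ids=separator_token_ids,  # or rely on `model_type='llama'`
            PADDING_ID=PADDING_ID,
            layer_num=config.num_hidden_layers,
        )
    return past_key_values

def update_sepcache_in_attention_forward(past_key_values, key_states, value_states, input_ids, layer_idx, q_len):
    # (2) Inside `XXXAttention.forward`: pass the current `input_ids` when updating the cache.
    key_states, value_states = past_key_values.update(
        key_states=key_states,
        value_states=value_states,
        input_ids=input_ids,          # required by SepCache
        layer_idx=layer_idx,
        PREFILLING_FLAG=q_len > 1,    # prefilling vs. decoding
    )
    return key_states, value_states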

3.3 Important Note

The shape of input_ids is [batch_size, seq_len] during prefilling, and [batch_size, 1] during generation.

4. Other Advanced Usage

Please refer to https://github.com/HKUDS/SepLLM, which provides detailed explanations and examples.
