SepCache - Native Sparse Attention Cache
This repository is a fork of transformers-community/sep_cache maintained for backup purposes. Please prioritize the official HuggingFace repository transformers-community/sep_cache; the two are identical.
1. Abstract
`SepCache` is a simple yet effective, native sparse attention Cache class proposed in the SepLLM paper (ICML 2025), and it is the cache design that most closely aligns with the semantic distribution of natural language. In the training phase, SepLLM condenses the information of each segment into the KV of the separator token that divides it. In the inference phase, the corresponding `SepCache` only needs to store the KVs of initial tokens, separator tokens, and recent tokens for generation.
Notably, `SepCache` also delivers strong performance across many tasks in training-free scenarios. Moreover, SepLLM (or simply `SepCache`) is the most suitable baseline method for sparse attention mechanisms and KV compression/management, as it is the native sparse attention mechanism that best aligns with the natural semantic distribution of language.
See more details and advanced usage at https://github.com/HKUDS/SepLLM.
2. Usage
2.1 Sample Base Model
We recommend using models from the Llama 3 series. Our example model is meta-llama/Meta-Llama-3-8B-Instruct, for which we have already prepared a targeted monkey patch.
For other models, using `SepCache` requires minor modifications to the corresponding modeling_xxx.py file or writing a custom monkey patch. These changes are very simple: you only need to pass arguments like `input_ids` to the `update` function of `SepCache` when calling it.
We will provide a detailed guide later on how to modify your modeling_xxx.py file or monkey patch file to adapt `SepCache` to any model.
2.2 Quick Start
2.2.1 Environment Setup
You need to install `transformers>=4.53.0,<4.54.0`, and we recommend `lm_eval>=0.4.9` for running evaluations. We suggest managing your Python environment with conda for better dependency control.
conda create -n sepcache python=3.10
conda activate sepcache
pip install "transformers>=4.53.0,<4.54.0"
pip install lm_eval==0.4.9
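Optionally, you can confirm the installed transformers version before running the examples. The check below is a convenience sketch (not part of this repository); `packaging` is normally available alongside transformers, so no extra install is needed.
# Optional version sanity check (illustrative, not part of the SepCache repo).
import transformers
from packaging.version import Version

v = Version(transformers.__version__)
assert Version("4.53.0") <= v < Version("4.54.0"), f"Unsupported transformers version: {v}"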
2.2.2 A Simple Example
You can use `SepCache` by specifying custom_generate="transformers-community/sep_cache" or custom_generate="Gausson/sep_cache" when calling the `generate` function. In our demo, we have already prepared sample monkey patching for the Llama 3 series models and provided some common parameters for initializing `SepCache`.
# requires `transformers>=4.53.0,<4.54.0`
from transformers import AutoModelForCausalLM, AutoTokenizer
# Preparing model, tokenizer, and model inputs
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto")
messages = [{"role": "user", "content": "Tell me a story about a cat."}]
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
    enable_thinking=False
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
# Using SepCache for generation
gen_out = model.generate(
    # usual `generate` arguments
    **model_inputs,
    do_sample=False,
    max_new_tokens=100,
    return_dict_in_generate=True,
    monkey_patch_verbose=True,  # To see which functions are actually being monkey patched for `SepCache`.
    # Using SepCache
    custom_generate="transformers-community/sep_cache",  ## Alternatively, you can use `Gausson/sep_cache`
    trust_remote_code=True,
    # SepCache arguments
    init_cache_size=4,
    sep_cache_size=128,
    local_size=256,
    cache_size=512,
    USE_MAX_SEP_CACHE=True,
    model_type='llama'
)
print(tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True))
assert "sepcache" in str(type(gen_out.past_key_values)).lower()
It is worth noting that you must specify the `separator_token_ids: List[int]` and `PADDING_ID: int` parameters when initializing `SepCache`. We did not do so in the example above because, for convenience, we specified `model_type = "llama"`, in which case `separator_token_ids` and `PADDING_ID` are filled in automatically.
However, when you use a tokenizer for a non-Llama-3-series model, you need to specify the values of `separator_token_ids` and `PADDING_ID` according to the tokenizer you are using. For example, the following call uses the values obtained from a Llama 3 series tokenizer (a sketch for deriving such values from an arbitrary tokenizer follows the code).
# Using SepCache for generation
gen_out = model.generate(
    # usual `generate` arguments
    **model_inputs,
    do_sample=False,
    max_new_tokens=100,
    return_dict_in_generate=True,
    monkey_patch_verbose=True,  # To see which functions are actually being monkey patched for `SepCache`.
    # Using SepCache
    custom_generate="transformers-community/sep_cache",  ## Alternatively, you can use `Gausson/sep_cache`
    trust_remote_code=True,
    # SepCache arguments
    init_cache_size=4,
    sep_cache_size=128,
    local_size=256,
    cache_size=512,
    USE_MAX_SEP_CACHE=True,
    separator_token_ids=[128000, 13, 11, 30, 0, 26, 25, 198, 220, 662, 1174, 949, 758, 2652, 551, 720, 256, 262],
    PADDING_ID=128009
)
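If you are working with a different tokenizer, a sketch like the one below can help you collect candidate separator token ids. It is a hypothetical helper (not part of this repository): which separators you include, and whether you keep their leading-space variants, should be checked against the SepLLM paper and your tokenizer's vocabulary.
# Hypothetical helper for collecting separator token ids from an arbitrary tokenizer.
# Verify the resulting ids manually before passing them to SepCache.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

candidate_separators = [".", ",", "?", "!", ";", ":", "\n", "\t", " "]  # assumption: common separators
separator_token_ids = set()
for sep in candidate_separators:
    for variant in (sep, " " + sep):  # many BPE tokenizers encode " ." differently from "."
        ids = tokenizer.encode(variant, add_special_tokens=False)
        if len(ids) == 1:  # keep only separators that map to a single token
            separator_token_ids.add(ids[0])

separator_token_ids = sorted(separator_token_ids)
PADDING_ID = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
print(separator_token_ids, PADDING_ID)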
2.2.3 Frequently-Used Parameters
Below, we provide explanations and examples for the most commonly used parameters for initializing `SepCache`. These parameters can be passed through the `generate` function.
`SepCache` stores the Key and Value states as lists of tensors, two lists for each layer. The expected shape for each tensor is
`[batch_size, num_heads, seq_len, head_dim]`.
Frequently-Used Parameters:
`init_cache_size: Union[int, List]`:
The maximum number of KVs to be stored for initial tokens.
In the paper, the hyperparameter `a` is an abbreviated alias for `init_cache_size`.
`sep_cache_size: Union[int, List]`:
The maximum number of KVs to be stored for separator tokens.
In the paper, the hyperparameter `s` is an abbreviated alias for `sep_cache_size`.
`local_size: Union[int, List]`:
The maximum number of KVs to be stored for local tokens (i.e., sliding window).
In the paper, the hyperparameter `w` is an abbreviated alias for `local_size`.
`cache_size: Union[int, List]`:
The maximum number of KVs to be stored for all the tokens, i.e., the size for the whole KV cache.
In the paper, the hyperparameter `c` is an abbreviated alias for `cache_size`.
Concerning the four parameters above:
When a list is passed (its length must be `layer_num`), it specifies a different value for each layer.
When an integer is passed, the same setting is used for all layers. (See the short example after this parameter list.)
`USE_MAX_SEP_CACHE: bool`:
If True, it means we only keep at most `sep_cache_size` separators' KVs.
If the number exceeds this limit, older separators' KVs will be discarded, keeping only the most recent `sep_cache_size` KVs.
`separator_token_ids: List[int]`:
The token ids of the separator tokens for the current model's tokenizer.
We have some examples, such as the Llama-3 series models, where setting `model_type='llama'` allows you
to skip setting `separator_token_ids` and `PADDING_ID` (SepCache will auto-fill them).
`PADDING_ID: int`:
The token id of the padding token. You can simply set `PADDING_ID` to the id of the "<|endoftext|>" token of the tokenizer for the pretrained model.
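As a short illustration of the list (per-layer) form of the four size parameters mentioned above (the values below are made up for demonstration; Llama-3-8B has 32 decoder layers):
# Illustrative per-layer configuration (same budget for every layer here, but each entry may differ).
layer_num = 32                           # e.g., Llama-3-8B has 32 decoder layers
init_cache_size = [4] * layer_num
sep_cache_size = [128] * layer_num
local_size = [256] * layer_num           # sliding-window budget per layer
cache_size = [512] * layer_num           # total KV budget per layer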
Important Note:
- When `cache_size` and `local_size` are set to infinity (i.e., sufficiently large positive integers) and `USE_MAX_SEP_CACHE` is `False`, `SepCache` degenerates into a regular Cache.
- You must always ensure that `init_cache_size` + `sep_cache_size` + `local_size` + `left_padding_offset` < `cache_size`. Here, `left_padding_offset` denotes the number of padding tokens in the record with the largest left padding within a runtime batch; it can only be determined at runtime.
- To guarantee that the above inequality always holds at runtime, you can intentionally leave a sufficient margin between the two sides of the inequality `init_cache_size` + `sep_cache_size` + `local_size` < `cache_size` (i.e., `a` + `s` + `w` < `c` in the SepLLM paper, ICML 2025) to leave room for `left_padding_offset`. A quick sanity check is sketched below.
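For example, checking the demo settings used above (a convenience sketch, not part of SepCache):
# Sanity-check the budget inequality a + s + w < c, leaving room for `left_padding_offset`.
init_cache_size, sep_cache_size, local_size, cache_size = 4, 128, 256, 512

margin = cache_size - (init_cache_size + sep_cache_size + local_size)
assert margin > 0, "init_cache_size + sep_cache_size + local_size must be strictly less than cache_size"
print(f"Margin left for left_padding_offset at runtime: {margin} tokens")  # 124 for these settings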
More Important Note: In practice, there is no need to apply positional encoding (PE) shifting (as in StreamingLLM) if the actual sequence length does not exceed the pretrained maximum PE length, which is the case for most downstream tasks. So, for most basic usage, just set `APPLY_PE_SHIFT=False` (the default) and `APPLY_PES_INSIDE=False` for initialization.
2.2.4 Update Function
After initialization, another key point to note is that when using the `update` function of `SepCache` to update the keys/values and the past token ids (which is necessary in `SepCache`), the current `input_ids` must also be provided.
key_states, value_states = past_key_values.update(
    key_states=key_states,
    value_states=value_states,
    input_ids=input_ids,        ## required
    layer_idx=layer_idx,
    PREFILLING_FLAG=q_len > 1,  ## `q_len` is the sequence length of the current `query_states`
)
2.2.5 Monkey Patch Demo
We use a monkey patch to adapt models to the `update` function of `SepCache` described in 2.2.4 Update Function, i.e., to pass the current `input_ids` as a parameter to the `update` function. It is worth noting that during the prefilling stage the shape of the `input_ids` tensor is `[batch_size, seq_len]`, while during the decoding stage of auto-regressive models the shape of the `input_ids` tensor should be `[batch_size, 1]`.
In our custom_generate/generate.py file, we provide the monkey_patching function, which works by replacing the `forward` function of all the related instances of the `XXXAttention` class (for example, `LlamaAttention` in the Llama 3 series models) with our customized forward function (specified by the `model_atten_forward` parameter of the monkey_patching function).
def monkey_patching(model_obj,
                    model_atten_forward,  ## The `forward` function used to patch.
                    possible_inner_model_names: List[str] = ["model", "transformer", "gpt_neox"],  # In the `XXXForCausalLM` class, the possible names of the internal attribute for the model, e.g., "model", "transformer", "gpt_neox", etc.
                    possible_layers_names: List[str] = ["layers", "h"],  # In the `XXXModel` class, the possible names of the internal attribute for the decoder layers, e.g., "layers", "h", etc.
                    atten_attr_name_pattern_list: List[str] = ["attention", "self_attn"],  # In the `XXXDecoderLayer` class, the possible name patterns of the internal attribute for self-attention, e.g., "attention", "self_attn", etc.
                    atten_attr_name_pattern_exclude: List[str] = ["norm", "layer"],  # In the `XXXDecoderLayer` class, the name patterns to be excluded for the self-attention attribute. Sometimes there are attributes like "post_attention_norm" whose `forward` we do not want to modify - we only want to modify the `forward` of `XXXAttention` - so patterns like "norm" must be excluded to accurately locate the correct `forward` function to replace.
                    verbose=True):
    """
    This `monkey_patching` function is to
        - find the `forward` function of the `XXXAttention` class.
        - replace all the related `forward` functions of the instances of the `XXXAttention` class with `model_atten_forward`.
    """
    ## To avoid the argument check failure, i.e., let "sepllm_kwargs" pass the check.
    transformers.generation.GenerationMixin._validate_model_kwargs = _validate_model_kwargs

    ## Get the inner model obj
    inner_model_type = PreTrainedModel
    inner_model = find_inner_attribute(model_obj, possible_inner_model_names, inner_model_type)

    ## Get the decoder layers (`nn.ModuleList`) obj
    layers_type = nn.ModuleList
    model_layers = find_inner_attribute(inner_model, possible_layers_names, layers_type)

    ## Replace all the related `forward` functions of the `XXXAttention` class's instances.
    for i, decoder_layer in enumerate(model_layers):
        self_attn_module = find_attribute_name(decoder_layer, atten_attr_name_pattern_list, atten_attr_name_pattern_exclude, nn.Module)
        result = monkey_patch_by_class_path(self_attn_module, model_atten_forward)
        if verbose:
            decoder_class_name = get_importable_class_path(decoder_layer)
            print(f"For Layer {i}'s `{decoder_class_name}`: {result}")

    return model_layers
The monkey_patching function primarily does three things:
- Precisely locate the `forward` function of all instances of the `XXXAttention` class.
- Replace that `forward` function with the `model_atten_forward` function you provide.
- Return the decoder-layer attribute found during the process, typically of type `nn.ModuleList`. This return value (`model_layers`) is only used later to determine the number of layers in the current model (obtained by `len(model_layers)`).
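For illustration only, a direct invocation might look like the snippet below. It assumes you import monkey_patching from custom_generate/generate.py and have written your own attention forward (here called my_atten_forward, a placeholder name); normally this call happens inside our customized `generate` function.
# Hypothetical direct use of `monkey_patching` (placeholder `my_atten_forward`).
model_layers = monkey_patching(
    model,                                 # your XXXForCausalLM instance
    model_atten_forward=my_atten_forward,  # your customized attention `forward`
    verbose=True,                          # print which `forward` functions get replaced
)
layer_num = len(model_layers)              # later used to initialize SepCache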
In addition, the monkey_patching function replaces transformers.generation.GenerationMixin._validate_model_kwargs with our _validate_model_kwargs to bypass some parameter checks, since we pass an additional sepllm_kwargs parameter to wrap the `input_ids` for eventual transmission to the `update` function of `SepCache`.
Please ensure that the monkey_patching function accurately locates and replaces the `forward` function of the `XXXAttention` class. The current monkey_patching is designed for the Llama 3 series models. For other models, you need to modify monkey_patching appropriately to ensure correct targeting and replacement! You can monitor the monkey-patching process by setting `verbose=True` in the monkey_patching function (or `monkey_patch_verbose=True` in the `generate` function).
def truncate_input_ids_4_autoregression(input_ids, key_states):
    if input_ids.shape[-1] != key_states.shape[-2]:
        assert input_ids.shape[-1] >= key_states.shape[-2]
        truncated_input_ids = input_ids[..., -key_states.shape[-2]:]
        return truncated_input_ids
    else:
        return input_ids
The truncate_input_ids_4_autoregression function in the custom_generate/generate.py file is used to shape the `input_ids` tensor to `[batch_size, 1]` during decoding.
2.2.6 Downstream Task Evaluation
We recommend using `lm_eval==0.4.9` for downstream task evaluation. You can pass model-related parameters via --model_args and generation-related parameters (including those required for initializing `SepCache`) via --gen_kwargs. Notably, you typically need to pass a list to `separator_token_ids` using a string format like "id1;id2;id3" (as shown in the example below).
lm_eval --model hf \
--model_args pretrained=meta-llama/Meta-Llama-3-8B-Instruct,attn_implementation=flash_attention_2 \
--tasks gsm8k_cot \
--gen_kwargs custom_generate=transformers-community/sep_cache,trust_remote_code=True,monkey_patch_verbose=True,init_cache_size=4,sep_cache_size=128,local_size=256,cache_size=512,separator_token_ids="128000;13;11;30;0;26;25;198;220;662;1174;949;758;2652;551;720;256;262",PADDING_ID=128009 \
--device cuda:0 \
--batch_size 80 2>&1 | tee log.txt
Note: `SepCache` is typically used in combination with Flash Attention to maximize generation efficiency.
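For example, assuming the flash-attn package is installed, you can enable it when loading the model (the dtype choice below is just a common default):
# Load the model with Flash Attention 2 enabled (requires the `flash-attn` package).
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)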
2.2.7 The Detailed Signature of the generate Function
Here is the detailed signature of our customized `generate` function for `SepCache` in the custom_generate/generate.py file:
def generate(model,
             ## For SepCache
             init_cache_size: Union[int, List] = 4,
             sep_cache_size: Union[int, List] = 128,
             local_size: Union[int, List] = 256,
             cache_size: Union[int, List] = 512,
             SEP_ACCUMULATION: bool = True,
             USE_MAX_SEP_CACHE: bool = False,
             SEP_PADDING_IN_BATCH: bool = False,
             separator_token_ids: List[int] = None,  ## Required for initialization if `model_type` is not provided.
             PADDING_ID: int = None,  ## Required for initialization if `model_type` is not provided.
             ## For inheritance & initialization states
             past_tok_ids: List[torch.Tensor] = None,  ## It saves all the token ids corresponding to the saved KVs for all layers in SepCache.
             key_cache: List[torch.Tensor] = None,
             value_cache: List[torch.Tensor] = None,
             ## For debugging
             PRINT_KV_RATIO_INSIDE: bool = False,
             print_KV_inside_per_steps: int = 1000,
             _seen_tokens: int = 0,
             _kept_kv_ratio: List[Tuple[int]] = None,
             ### For positional encoding shifting
             APPLY_PE_SHIFT: bool = False,
             APPLY_PES_INSIDE: bool = False,
             _shifted_position_ids: List[torch.Tensor] = None,
             _rope_unsqueeze_dim: int = 1,  ## The unsqueeze_dim when applying RoPE.
             _rope_seq_dim: int = 1,  ## The seq_len dimension for the `cos` or `sin` tensors.
             pe_scaling_factor: float = 1.0,
             pe_dim: int = 128,  ## The number of dims for positional encoding. Typically, just set the `head_dim` to this.
             max_position_embeddings: int = 8192,
             base: int = 10000,  ## The base for RoPE.
             ## For basic transformer architecture
             k_seq_dim: int = 2,  ## The dimension for seq_len in key tensors.
             v_seq_dim: int = 2,  ## The dimension for seq_len in value tensors.
             layer_num: int = None,  ## Required for initialization.
             model_type: str = 'llama',  ## The model type for running the example. Choose from ['llama', 'pythia', 'falcon'].
             device=None,
             ## For verbosity of monkey patching
             monkey_patch_verbose: bool = False,
             **kwargs
             ):
    ...
3. Adaptation for Other Models
Adapting `SepCache` to various models is simple; there are two approaches.
3.1 Method 1 - Monkey Patching
- Modify the monkey_patching function to correctly locate and target the `forward` function of your model's `XXXAttention` class (e.g., `LlamaAttention` for Llama 3).
- Write your custom `model_atten_forward` function and use monkey_patching to replace the `forward` function of all `XXXAttention` class instances. The key modification is passing `input_ids` to the `update` function of `SepCache`.
3.2 Method 2 - Direct Code Modification (Recommended for Simplicity)
Simply edit your modeling_xxx.py file to implement the following:
- Initialize `past_key_values` as a `SepCache` instance at the appropriate location (e.g., in the `forward` function of the `XXXForCausalLM` or `XXXModel` class).
- Modify the `forward` function of the `XXXAttention` class to pass `input_ids` to the `update` function of `SepCache` (a minimal sketch follows).
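A minimal sketch of that key change inside the attention forward is shown below. The surrounding signature and the projection step are placeholders (the real ones depend on your model); the essential point is forwarding the current `input_ids` to the `update` function of `SepCache`.
# Hypothetical sketch of the modified `XXXAttention.forward` (placeholder names and signature).
def forward(self, hidden_states, past_key_values=None, input_ids=None, **kwargs):
    # Placeholder projection: the real code computes query/key/value states as usual.
    query_states, key_states, value_states = self.project_qkv(hidden_states)
    q_len = query_states.shape[-2]

    if past_key_values is not None:  # here `past_key_values` is a SepCache instance
        key_states, value_states = past_key_values.update(
            key_states=key_states,
            value_states=value_states,
            input_ids=input_ids,        # required by SepCache
            layer_idx=self.layer_idx,
            PREFILLING_FLAG=q_len > 1,  # True during prefilling, False during decoding
        )
    # ... the attention computation continues with the updated key/value states ...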
3.3 Important Note
The shape of `input_ids` is `[batch_size, seq_len]` during prefilling and `[batch_size, 1]` during generation.
4. Other Advanced Usage
Please refer to https://github.com/HKUDS/SepLLM, which provides detailed explanations and examples.