WERSA: Wavelet-Enhanced Random Spectral Attention
This repository provides the official implementation of WERSA, a novel attention mechanism with linear O(n) time complexity, designed to scale Transformer models to very long sequences without a performance trade-off.
Our paper, "Scaling Attention to Very Long Sequences in Linear Time with Wavelet-Enhanced Random Spectral Attention (WERSA)", is available on arXiv:2507.08637.
The Science Behind WERSA
Standard attention mechanisms have quadratic (O(n²)) complexity, which makes processing long sequences impractical. WERSA solves this by combining several powerful principles to achieve linear (O(n)) efficiency while maintaining high performance.
- Multi-Resolution Analysis: Uses Haar wavelet transforms to decompose the input into multiple scales, capturing both local details and global context.
- Adaptive Filtering: An MLP generates input-dependent filters, and learnable scale_weights modulate each wavelet level, letting the model dynamically prioritize the most informative frequency components.
- Linear Complexity via Random Features: Uses random feature projection to approximate the softmax kernel, avoiding the full quadratic attention matrix (see the sketch below).
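For intuition, here is a minimal, self-contained sketch of two of these ingredients in isolation: a single Haar decomposition level and a Performer-style positive random-feature approximation of softmax attention. It is illustrative only and is not the WERSA implementation itself, which combines multi-level decomposition with the adaptive filters and learnable scale weights described above (see the paper for the exact formulation).

```python
import torch

def haar_level(x):
    """One Haar wavelet level along the sequence axis.

    x: (batch, seq_len, dim) with even seq_len.
    Returns (approximation, detail), each of shape (batch, seq_len // 2, dim).
    """
    even, odd = x[:, 0::2, :], x[:, 1::2, :]
    approx = (even + odd) / 2 ** 0.5   # coarse, low-frequency content
    detail = (even - odd) / 2 ** 0.5   # fine, high-frequency content
    return approx, detail

def random_feature_attention(q, k, v, num_features=64):
    """Attention via positive random features that approximate the softmax
    kernel, so the cost grows as O(n) in sequence length instead of O(n^2)."""
    dim = q.shape[-1]
    q, k = q / dim ** 0.25, k / dim ** 0.25                  # softmax temperature
    proj = torch.randn(dim, num_features, device=q.device)   # random projection
    phi = lambda x: torch.exp(x @ proj - x.pow(2).sum(-1, keepdim=True) / 2)
    phi_q, phi_k = phi(q), phi(k)                            # positive features
    kv = torch.einsum("bnf,bnd->bfd", phi_k, v)              # (batch, F, dim_v)
    norm = (phi_q @ phi_k.sum(dim=1).unsqueeze(-1)).clamp(min=1e-6)
    return torch.einsum("bnf,bfd->bnd", phi_q, kv) / norm

# Toy usage: decompose a sequence, then attend over the coarse scale.
x = torch.randn(2, 128, 64)
approx, detail = haar_level(x)
print(random_feature_attention(approx, approx, approx).shape)  # (2, 64, 64)
```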
Installation
First, ensure you have PyTorch and Hugging Face Transformers installed. Then, install the wersa package directly from this repository.
# 1. Install core dependencies (example for CUDA 12.1)
pip install torch --index-url https://download.pytorch.org/whl/cu121
pip install transformers
# 2. Install the WERSA package from this repository
pip install git+https://github.com/vincenzodentamaro/wersa.git
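After installation, a quick sanity check is to import the classes used in the Quickstart below; this only verifies that the package and its dependencies resolve correctly.

```python
# Sanity check: the wersa package and its dependencies are importable.
from wersa import WersaConfig, WersaForCausalLM

print("wersa imported:", WersaConfig.__name__, WersaForCausalLM.__name__)
```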
Quickstart: Building a Qwen-like Model with WERSA
You can easily build a Qwen-style causal language model with WERSA attention by importing the WersaConfig and WersaForCausalLM classes from the package.
Building an 8B Parameter Model
This snippet creates an ~8B parameter model with a configuration similar to state-of-the-art models like Qwen2-7B.
from wersa import WersaConfig, WersaForCausalLM
from transformers import AutoTokenizer
# Load a compatible tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B")
# Define the configuration for the 8B model
config_8b = WersaConfig(
vocab_size=len(tokenizer),
pad_token_id=tokenizer.pad_token_id,
hidden_size=4096,
num_hidden_layers=32,
num_attention_heads=32,
intermediate_size=11008,
max_position_embeddings=4096
)
# Instantiate the model
model_8b = WersaForCausalLM(config_8b)
print(f"8B Model created with ~{model_8b.num_parameters() / 1e9:.2f}B parameters.")
Building a 0.6B Parameter Model
This snippet creates a smaller ~0.6B parameter model, perfect for faster experiments or deployment on more constrained hardware.
from wersa import WersaConfig, WersaForCausalLM
from transformers import AutoTokenizer
# Load a compatible tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B")
# Define the configuration for the 0.6B model
config_0_6b = WersaConfig(
vocab_size=len(tokenizer),
pad_token_id=tokenizer.pad_token_id,
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
intermediate_size=2816,
max_position_embeddings=1024
)
# Instantiate the model
model_0_6b = WersaForCausalLM(config_0_6b)
print(f"0.6B Model created with ~{model_0_6b.num_parameters() / 1e9:.2f}B parameters.")
Training and Examples
This repository includes complete scripts to demonstrate how to pre-train these models from scratch and test their generation capabilities.
- train_and_generate_1b.py: A full example for training a ~1B parameter model.
- train_and_generate_8b.py: A full example for training the 8B parameter model.
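Those scripts contain the complete pipelines. As a rough illustration of how a pre-training loop could look, the sketch below wires the 0.6B configuration from the Quickstart into the standard Hugging Face Trainer; it is not taken from the repository's scripts, the dataset and hyperparameters are placeholders, and Trainer compatibility is assumed from the Hugging Face model interface.

```python
# Illustrative pre-training sketch (placeholders, not the repository's scripts).
from datasets import load_dataset
from transformers import (AutoTokenizer, DataCollatorForLanguageModeling,
                          Trainer, TrainingArguments)
from wersa import WersaConfig, WersaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B")
if tokenizer.pad_token is None:               # defensive: ensure a pad token
    tokenizer.pad_token = tokenizer.eos_token

model = WersaForCausalLM(WersaConfig(
    vocab_size=len(tokenizer),
    pad_token_id=tokenizer.pad_token_id,
    hidden_size=1024,
    num_hidden_layers=24,
    num_attention_heads=16,
    intermediate_size=2816,
    max_position_embeddings=1024,
))

# Placeholder corpus; swap in your own pre-training data.
raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
tokenized = raw.map(
    lambda batch: tokenizer(batch["text"], truncation=True, max_length=1024),
    batched=True, remove_columns=raw.column_names,
).filter(lambda ex: len(ex["input_ids"]) > 1)  # drop empty lines

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="wersa-0.6b-pretrain",
                           per_device_train_batch_size=2,
                           num_train_epochs=1,
                           logging_steps=50),
    train_dataset=tokenized,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
```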
Citation
If you find WERSA useful in your research, please consider citing our paper:
@misc{dentamaro2025scaling,
title={Scaling Attention to Very Long Sequences in Linear Time with Wavelet-Enhanced Random Spectral Attention (WERSA)},
author={Vincenzo Dentamaro},
year={2025},
eprint={2507.08637},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
License
This project is licensed under the Apache License 2.0.