WERSA: Wavelet-Enhanced Random Spectral Attention

This repository provides the official implementation of WERSA, a novel attention mechanism with linear O(n) time complexity, designed to scale Transformer models to very long sequences without a performance trade-off.

Our paper, "Scaling Attention to Very Long Sequences in Linear Time with Wavelet-Enhanced Random Spectral Attention (WERSA)", is available on arXiv:2507.08637.

🔬 The Science Behind WERSA

Standard attention mechanisms have quadratic (O(n²)) complexity in the sequence length n, which makes processing long sequences impractical. WERSA addresses this by combining three principles to achieve linear (O(n)) complexity while maintaining high performance.
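
To make the gap concrete, the following back-of-the-envelope sketch counts multiply-adds for a single attention head under two evaluation orders: forming the full n x n score matrix, versus a kernelized form that first summarizes keys and values into a small m x d matrix. The function names, the head dimension d = 128, and the feature count m = 256 are illustrative assumptions, not values taken from the WERSA implementation.

```python
def quadratic_attention_macs(n: int, d: int) -> int:
    """Multiply-adds for softmax(Q K^T) V using the full n x n matrix."""
    scores = n * n * d      # Q (n x d) times K^T (d x n)
    weighted = n * n * d    # scores (n x n) times V (n x d)
    return scores + weighted


def linear_attention_macs(n: int, d: int, m: int) -> int:
    """Multiply-adds for phi(Q) (phi(K)^T V) with m-dimensional feature maps."""
    features = 2 * n * d * m   # project Q and K (n x d) onto m features
    summary = n * m * d        # phi(K)^T (m x n) times V (n x d)
    readout = n * m * d        # phi(Q) (n x m) times the m x d summary
    return features + summary + readout


if __name__ == "__main__":
    d, m = 128, 256  # illustrative head dimension and feature count
    for n in (1_024, 16_384, 262_144):
        ratio = quadratic_attention_macs(n, d) / linear_attention_macs(n, d, m)
        print(f"n={n:>7}: quadratic/linear cost ratio = {ratio:.1f}")
```

Under these assumptions the ratio simplifies to n / (2m), so the advantage of the linear form grows in proportion to the sequence length.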

  • Multi-Resolution Analysis: Uses Haar wavelet transforms to decompose the input into multiple scales, capturing both local details and global context.
  • Adaptive Filtering: An MLP generates input-dependent filters, and learnable scale_weights modulate each wavelet level, allowing the model to dynamically prioritize the most informative frequency components.
  • Linear Complexity via Random Features: Uses random feature projection to approximate the softmax kernel, avoiding the computation of the full quadratic attention matrix.
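
Two of the ingredients above can be illustrated in a few lines of plain Python. This is a minimal sketch, not the code shipped in this repository: haar_decompose shows a multi-level orthonormal Haar transform on a 1-D signal, and positive_random_features shows one well-known way to approximate the exponential (softmax) kernel, namely positive random features in the style of Performer. WERSA's actual feature map, filters, and tensor layout may differ, and all function names here are hypothetical. The adaptive filtering step is not shown.

```python
import math
import random


def haar_step(signal):
    """One Haar level: scaled pairwise sums (coarse) and differences (detail)."""
    s = 1.0 / math.sqrt(2.0)
    approx = [(a + b) * s for a, b in zip(signal[0::2], signal[1::2])]
    detail = [(a - b) * s for a, b in zip(signal[0::2], signal[1::2])]
    return approx, detail


def haar_decompose(signal, levels):
    """Multi-level decomposition; len(signal) must be divisible by 2**levels."""
    details = []
    approx = list(signal)
    for _ in range(levels):
        approx, detail = haar_step(approx)
        details.append(detail)  # finest scale first
    return approx, details


def positive_random_features(x, omegas):
    """phi(x)_i = exp(w_i . x - |x|^2 / 2) / sqrt(m).

    For w_i drawn from N(0, I), E[phi(q) . phi(k)] = exp(q . k).
    """
    half_sq_norm = 0.5 * sum(v * v for v in x)
    scale = 1.0 / math.sqrt(len(omegas))
    return [
        scale * math.exp(sum(w * v for w, v in zip(omega, x)) - half_sq_norm)
        for omega in omegas
    ]


if __name__ == "__main__":
    x = [4.0, 2.0, 5.0, 5.0, 1.0, 3.0, 0.0, 2.0]
    coarse, details = haar_decompose(x, levels=2)
    print("coarse:", coarse)
    print("details per level:", details)

    rng = random.Random(0)
    d, m = 4, 4096  # illustrative sizes
    omegas = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(m)]
    q = [0.3, -0.1, 0.2, 0.0]
    k = [0.1, 0.4, -0.2, 0.3]
    exact = math.exp(sum(a * b for a, b in zip(q, k)))
    approx_kernel = sum(
        a * b
        for a, b in zip(positive_random_features(q, omegas),
                        positive_random_features(k, omegas))
    )
    print(f"exact kernel = {exact:.4f}, random-feature estimate = {approx_kernel:.4f}")
```

Because the kernel estimate is a plain dot product of feature vectors, sums over all keys can be computed once and reused for every query, which is what removes the need for the n x n attention matrix.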

βš™οΈ Installation

First, ensure you have PyTorch and Hugging Face Transformers installed. Then, install the wersa package directly from this repository.

# 1. Install core dependencies (example for CUDA 12.1)
pip install torch --index-url https://download.pytorch.org/whl/cu121
pip install transformers

# 2. Install the WERSA package from this repository
pip install git+https://github.com/vincenzodentamaro/wersa.git
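
After installing, a quick sanity check can confirm that the three packages are importable before you build a model. The helper below is a small sketch (missing_packages is a hypothetical name, not part of the wersa package); the import name wersa is taken from the Quickstart imports.

```python
import importlib.util


def missing_packages(names):
    """Return the top-level package names that cannot be found for import."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages(["torch", "transformers", "wersa"])
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages are importable.")
```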

🚀 Quickstart: Building a Qwen-like Model with WERSA

You can easily build a Qwen-style causal language model with WERSA attention by importing the WersaConfig and WersaForCausalLM classes from the package.

Building an 8B Parameter Model

This snippet creates an ~8B parameter model with a configuration similar to state-of-the-art models like Qwen2-7B.

from wersa import WersaConfig, WersaForCausalLM
from transformers import AutoTokenizer

# Load a compatible tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B")

# Define the configuration for the 8B model
config_8b = WersaConfig(
    vocab_size=len(tokenizer),
    pad_token_id=tokenizer.pad_token_id,
    hidden_size=4096,
    num_hidden_layers=32,
    num_attention_heads=32,
    intermediate_size=11008,
    max_position_embeddings=4096
)

# Instantiate the model
model_8b = WersaForCausalLM(config_8b)
print(f"8B Model created with ~{model_8b.num_parameters() / 1e9:.2f}B parameters.")
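
If you want a rough idea of the size before instantiating the model, the count can be approximated from the configuration alone. The sketch below assumes a generic gated-MLP decoder layout (four hidden x hidden attention projections and three hidden x intermediate MLP matrices per layer, untied input and output embeddings) and ignores biases, normalization weights, and any WERSA-specific parameters, so the real count reported by num_parameters() will differ somewhat. The function name is hypothetical, and 152_000 is a round placeholder for len(tokenizer).

```python
def estimate_decoder_params(vocab_size, hidden_size, num_hidden_layers,
                            intermediate_size, tie_embeddings=False):
    """Rough parameter count for a generic gated-MLP decoder-only Transformer.

    Assumes Q, K, V, and output projections of size hidden x hidden and three
    hidden x intermediate MLP matrices per layer. Biases, norms, and any
    attention-specific extras are ignored.
    """
    attention = 4 * hidden_size * hidden_size
    mlp = 3 * hidden_size * intermediate_size
    embeddings = vocab_size * hidden_size * (1 if tie_embeddings else 2)
    return num_hidden_layers * (attention + mlp) + embeddings


if __name__ == "__main__":
    # 152_000 is a placeholder for len(tokenizer), not the exact value.
    total = estimate_decoder_params(
        vocab_size=152_000,
        hidden_size=4096,
        num_hidden_layers=32,
        intermediate_size=11008,
    )
    print(f"Estimated parameters: ~{total / 1e9:.2f}B")
```

The same function can be applied to any other configuration by swapping in its hidden_size, num_hidden_layers, and intermediate_size.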

Building a 0.6B Parameter Model

This snippet creates a smaller ~0.6B parameter model, perfect for faster experiments or deployment on more constrained hardware.

from wersa import WersaConfig, WersaForCausalLM
from transformers import AutoTokenizer

# Load a compatible tokenizer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B")

# Define the configuration for the 0.6B model
config_0_6b = WersaConfig(
    vocab_size=len(tokenizer),
    pad_token_id=tokenizer.pad_token_id,
    hidden_size=1024,
    num_hidden_layers=24,
    num_attention_heads=16,
    intermediate_size=2816,
    max_position_embeddings=1024
)

# Instantiate the model
model_0_6b = WersaForCausalLM(config_0_6b)
print(f"0.6B Model created with ~{model_0_6b.num_parameters() / 1e9:.2f}B parameters.")
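
To judge whether a model fits on a given device, it helps to convert the parameter count into memory. The sketch below covers the weights only, using the nominal sizes quoted in this README (~0.6B and ~8B) as inputs; the helper name and the dtype table are illustrative assumptions.

```python
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "float16": 2, "int8": 1}


def weight_memory_gib(num_parameters, dtype="float32"):
    """Memory needed to hold the weights alone, in GiB."""
    return num_parameters * BYTES_PER_PARAM[dtype] / 2**30


if __name__ == "__main__":
    for label, n_params in (("~0.6B", 0.6e9), ("~8B", 8e9)):
        for dtype in ("float32", "bfloat16"):
            gib = weight_memory_gib(n_params, dtype)
            print(f"{label} model in {dtype}: {gib:.1f} GiB of weights")
```

Activations, gradients, and optimizer state come on top of this figure during training, so treat it as a lower bound.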

📖 Training and Examples

This repository includes complete scripts to demonstrate how to pre-train these models from scratch and test their generation capabilities.

  • train_and_generate_1b.py: A full example for training a ~1B parameter model.
  • train_and_generate_8b.py: A full example for training the 8B parameter model.

📜 Citation

If you find WERSA useful in your research, please consider citing our paper:

@misc{dentamaro2025scaling,
      title={Scaling Attention to Very Long Sequences in Linear Time with Wavelet-Enhanced Random Spectral Attention (WERSA)}, 
      author={Vincenzo Dentamaro},
      year={2025},
      eprint={2507.08637},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

📄 License

This project is licensed under the Apache License 2.0.
