
MLPMemory-Mistral-wikipedia

Model Description

MLP Memory introduces a retriever-pretrained parametric memory designed to bridge retrieval-augmented generation (RAG) and traditional parametric fine-tuning.
Instead of explicitly retrieving documents during inference, this model internalizes retrieval behavior by pretraining an MLP to imitate kNN retrievers across the full pretraining corpus.
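One natural way to read "imitate a kNN retriever" is as distribution matching: the MLP's next-token distribution is trained to be close to the retriever's. The sketch below illustrates this with a KL-divergence imitation loss; the function name and the exact objective are assumptions for illustration, not necessarily the paper's training loss.

```python
import math

def kl_imitation_loss(mlp_probs, knn_probs, eps=1e-12):
    """KL(knn || mlp): one natural objective for training an MLP head
    to reproduce a kNN retriever's next-token distribution.
    (Illustrative; the paper's actual loss may differ.)"""
    total = 0.0
    for q, p in zip(mlp_probs, knn_probs):
        q = max(q, eps)  # clamp to avoid log(0)
        p = max(p, eps)
        total += p * math.log(p / q)
    return total
```

The loss is zero when the MLP exactly matches the retriever's distribution and positive otherwise, so minimizing it drives the parametric memory toward the retriever's behavior.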

Quick Start

Step 1: Import Libraries and Initialize Models

from models import MLPMemory, MistralMLPModel  # from the MLPMemory repository
import transformers
from transformers import AutoModelForCausalLM, AutoConfig
from loguru import logger

base_lm_path = "mistralai/Mistral-7B-v0.3"
knn_generator_path = "Rubin-Wei/MLPMemory-Mistral-wikipedia"

# Load the tokenizer and base LM, then the pretrained MLP memory
tokenizer = transformers.AutoTokenizer.from_pretrained(base_lm_path)
base_lm = AutoModelForCausalLM.from_pretrained(base_lm_path)
config = AutoConfig.from_pretrained(knn_generator_path)
knn_generator = MistralMLPModel.from_pretrained(
    knn_generator_path,
    config=config,
    input_dim=config.hidden_size,
    output_dim=config.hidden_size,
)

Step 2: Prepare Models and Create Joint Model

# Align the embedding matrix with the tokenizer and switch to inference mode
base_lm.resize_token_embeddings(len(tokenizer))
base_lm.eval()
knn_generator.eval()

# lmbda weights the memory's distribution; knn_temp scales its sharpness
joint = MLPMemory(base_lm, knn_generator, lmbda=0.75, knn_temp=1.0).to("cuda")
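A minimal sketch of what `lmbda` and `knn_temp` plausibly do inside the joint model, assuming a kNN-LM-style probability interpolation (the function names here are illustrative; `MLPMemory`'s internals may combine the two heads differently):

```python
import math

def softmax(logits, temp=1.0):
    """Numerically stable softmax with an optional temperature."""
    z = [l / temp for l in logits]
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [e / s for e in exps]

def joint_probs(lm_logits, mem_logits, lmbda=0.75, knn_temp=1.0):
    # Assumed combination rule (kNN-LM style): interpolate in probability space,
    # with knn_temp sharpening or flattening the memory distribution.
    p_lm = softmax(lm_logits)
    p_mem = softmax(mem_logits, temp=knn_temp)
    return [lmbda * pm + (1.0 - lmbda) * pl for pm, pl in zip(p_mem, p_lm)]
```

With `lmbda=0.75`, three quarters of the probability mass comes from the memory head; setting `lmbda=0` recovers the base LM exactly.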

Step 3: Generate Text and Compare Results

prompt = f"Answer the questions:\n\nQuestion: who sings i can't take my eyes off of you?? The answer is:"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

# MLP Memory
out_ids = joint.generate(
    **inputs,
    max_new_tokens=20,
    do_sample=False
)
logger.info(f"MLP Memory output: {tokenizer.decode(out_ids[0], skip_special_tokens=True)}")

# base model
out_ids = base_lm.generate(
    **inputs,
    max_new_tokens=20,
    do_sample=False
)
logger.info(f"Base Model output: {tokenizer.decode(out_ids[0], skip_special_tokens=True)}")

📊 Generation Results Comparison:

| Model | Generated Continuation |
| --- | --- |
| Base Model | "...who sings i can't take my eyes off of you?? The answer is: Andy Williams..." |
| +MLP Memory | "...who sings i can't take my eyes off of you?? The answer is: Frankie Valli. ;)..." |

Key Features

  • βš™οΈ End-to-End Differentiable β€” Unlike non-parametric retrievers, MLP Memory is fully parameterized and supports gradient flow, enabling joint optimization with the base model.
  • πŸ’Ύ Highly Compressed Knowledge β€” Compresses massive retrieval stores (e.g., 40 TB for 5 B tokens) into a compact 1 B-parameter MLP (~4 GB) while improving overall performance.
  • ⚑ Efficient Inference β€” Eliminates retrieval overhead, achieving faster inference than RAG and kNN-LM, with constant speed regardless of corpus size.
  • 🧠 Long-Term Memory β€” Functions as a durable repository capturing the full pretraining corpus, extending beyond short-term context memory.
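The compression figure in the second bullet can be checked with back-of-envelope arithmetic (sizes taken from the bullet; float32 storage is assumed):

```python
# Back-of-envelope check of the compression claim above.
datastore_tb = 40            # kNN datastore for a ~5B-token corpus (per the bullet)
mlp_params = 1e9             # ~1B parameters in the MLP memory
bytes_per_param = 4          # float32 (assumed)

mlp_gb = mlp_params * bytes_per_param / 1e9      # -> 4 GB
ratio = datastore_tb * 1000 / mlp_gb             # -> 10,000x smaller
print(f"MLP memory: {mlp_gb:.0f} GB, compression ~{ratio:,.0f}x")
```

Under these assumptions the MLP is roughly four orders of magnitude smaller than the datastore it replaces.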

Training Details

| Setting | Value |
| --- | --- |
| Optimizer | AdamW |
| Learning Rate | 4e-4 |
| LR Scheduler | Linear |
| Warmup Steps | 2000 |
| Epochs | 6 |
| Total Parameters (Memory) | ~1.4B |
| Dataset | English Wikipedia (Dec 2021) |
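The learning-rate schedule in the table (linear scheduler, 2000 warmup steps, peak 4e-4) can be sketched as a simple function; `total_steps` below is illustrative, since the true value depends on corpus size and batch size:

```python
def linear_warmup_lr(step, base_lr=4e-4, warmup_steps=2000, total_steps=100_000):
    """Linear warmup to base_lr, then linear decay to zero.
    Matches the schedule described in the training table;
    total_steps is an assumed placeholder."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    return base_lr * max(0.0, (total_steps - step) / (total_steps - warmup_steps))
```

The rate rises from 0 to 4e-4 over the first 2000 steps, then decays linearly to 0 by the final step.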

Citation

@inproceedings{Wei2025MLPMA,
  title={MLP Memory: A Retriever-Pretrained Memory for Large Language Models},
  author={Rubin Wei and Jiaqi Cao and Jiarui Wang and Jushi Kai and Qipeng Guo and Bowen Zhou and Zhouhan Lin},
  year={2025},
  url={https://api.semanticscholar.org/CorpusID:281658735}
}

Contact

For questions and discussions, feel free to email: [email protected]
