# MLPMemory
MLP Memory introduces a retriever-pretrained parametric memory designed to bridge retrieval-augmented generation (RAG) and traditional parametric fine-tuning. Instead of explicitly retrieving documents at inference time, the model internalizes retrieval behavior by pretraining an MLP to imitate a kNN retriever over the full pretraining corpus.
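Conceptually, the memory's next-token distribution is interpolated with the base LM's, in the spirit of kNN-LM. The sketch below is illustrative only: the `interpolate` helper, the assumption that `lmbda` weights the memory side, and the toy logits are not the repository's actual API.

```python
import math

def softmax(logits, temp=1.0):
    """Convert logits to probabilities at a given temperature."""
    exps = [math.exp(l / temp) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def interpolate(lm_logits, memory_logits, lmbda=0.75, knn_temp=1.0):
    """Mix the two next-token distributions (kNN-LM-style):
    p(w) = lmbda * p_mem(w) + (1 - lmbda) * p_lm(w).
    Assumes lmbda weights the memory distribution."""
    p_lm = softmax(lm_logits)
    p_mem = softmax(memory_logits, temp=knn_temp)
    return [lmbda * m + (1 - lmbda) * l for m, l in zip(p_mem, p_lm)]

# Toy 3-token vocabulary: the base LM prefers token 0, but the
# memory sharply prefers token 2, so the mixture picks token 2.
mixed = interpolate([2.0, 1.0, 0.5], [0.0, 0.0, 5.0])
print([round(p, 3) for p in mixed])
```

With `lmbda=0.75` (the value passed to `MLPMemory` below), a confident memory can override the base LM's top choice, which is the behavior the generation comparison illustrates.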
```python
from models import MLPMemory, MistralMLPModel
import transformers
from transformers import AutoModelForCausalLM, AutoConfig
from loguru import logger

base_lm_path = "mistralai/Mistral-7B-v0.3"
knn_generator_path = "Rubin-Wei/MLPMemory-Mistral-wikipedia"

# Load the base LM and its tokenizer.
tokenizer = transformers.AutoTokenizer.from_pretrained(base_lm_path)
base_lm = AutoModelForCausalLM.from_pretrained(base_lm_path)

# Load the pretrained MLP memory (the kNN generator).
config = AutoConfig.from_pretrained(knn_generator_path)
knn_generator = MistralMLPModel.from_pretrained(
    knn_generator_path,
    config=config,
    input_dim=config.hidden_size,
    output_dim=config.hidden_size,
)

base_lm.resize_token_embeddings(len(tokenizer))
base_lm.eval()
knn_generator.eval()

# Combine the two models; lmbda weights the memory distribution.
joint = MLPMemory(base_lm, knn_generator, lmbda=0.75, knn_temp=1.0).to("cuda")

prompt = "Answer the questions:\n\nQuestion: who sings i can't take my eyes off of you?? The answer is:"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

# MLP Memory
out_ids = joint.generate(**inputs, max_new_tokens=20, do_sample=False)
logger.info(f"MLP Memory output: {tokenizer.decode(out_ids[0], skip_special_tokens=True)}")

# Base model
out_ids = base_lm.generate(**inputs, max_new_tokens=20, do_sample=False)
logger.info(f"Base Model output: {tokenizer.decode(out_ids[0], skip_special_tokens=True)}")
```
## Generation Results Comparison
| Model | Generated Continuation |
|---|---|
| Base Model | "...who sings i can't take my eyes off of you?? The answer is: Andy Williams..." |
| +MLP Memory | "...who sings i can't take my eyes off of you?? The answer is: Frankie Valli. ;)..." |
## Training Configuration

| Setting | Value |
|---|---|
| Optimizer | AdamW |
| Learning Rate | 4e-4 |
| LR Scheduler | Linear |
| Warmup Steps | 2000 |
| Epochs | 6 |
| Total Parameters (Memory) | ~1.4B |
| Dataset | English Wikipedia (Dec 2021) |
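The warmup and learning-rate settings from the table can be sketched as a plain schedule function. This is a reproduction aid under stated assumptions: the table only fixes the peak LR (4e-4), warmup (2000 steps), and scheduler type (linear); `total_steps` here is an illustrative placeholder.

```python
def linear_schedule_lr(step, peak_lr=4e-4, warmup_steps=2000, total_steps=100_000):
    """Linear warmup to peak_lr over warmup_steps, then linear
    decay to zero by total_steps (total_steps is an assumption)."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    return peak_lr * max(0.0, (total_steps - step) / (total_steps - warmup_steps))

print(linear_schedule_lr(1000))  # mid-warmup: half of the peak LR
print(linear_schedule_lr(2000))  # end of warmup: peak LR
```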
## Citation

```bibtex
@inproceedings{Wei2025MLPMA,
  title={MLP Memory: A Retriever-Pretrained Memory for Large Language Models},
  author={Rubin Wei and Jiaqi Cao and Jiarui Wang and Jushi Kai and Qipeng Guo and Bowen Zhou and Zhouhan Lin},
  year={2025},
  url={https://api.semanticscholar.org/CorpusID:281658735}
}
```
For questions and discussions, feel free to email: [email protected]
Base model: mistralai/Mistral-7B-v0.3