
# Memorizing Transformer with Grouped Query Attention



An extended GPT-2-style large language model (LLM) that implements core components from the research paper **Memorizing Transformers (Wu et al., 2022)**.

This project incorporates Grouped Query Attention (GQA), KNN-based memory retrieval, XL-style attention, and Rotary Positional Encoding (RoPE).  

The training pipeline supports distributed training, data parallelism, and sharded dataset streaming.



---



## Key Features



- **Grouped Query Attention**: Multiple query heads are grouped to share each key/value head, reducing K/V projection cost and memory
- **KNN-based Memory**: Long-term memory retrieval from past activations using a learned KNN mechanism
- **XL-style Attention**: Recurrence-based memory layers adapted for KNN and grouped attention logic
- **Rotary Positional Encoding**: More efficient and generalizable positional representation than vanilla sin-cos encoding
- **Sharded Dataset Loader**: Handles large datasets with sharding and supports data parallelism via PyTorch DDP
- **Custom Memory Clearing Logic**: Memory reset and lifespan mechanisms tuned for stability and performance during training
- **Mixed Precision & DDP Training**: Efficient large-scale training using `torch.autocast` and `torchrun`



---



## Project Structure



```bash
MEM_TRANSFORMER/
├── configs/
│   └── config.json                  # Model + training hyperparameters
│
├── data/
│   ├── edu_fineweb/                 # Token-sharded training data
│   │   ├── train_000001.npy
│   │   ├── train_000002.npy
│   │   └── test_000001.npy
│   ├── hellaswag/
│   │   └── hellaswag_val.jsonl
│   └── fineweb.py                   # Sharding logic with memory-aligned sequence control
│
├── model_core/
│   ├── __init__.py
│   ├── attention.py                 # Grouped Query Attention, KNN & XL attention logic, Rotary Positional Encoding implementation
│   ├── model.py                     # Transformer model with memory and RoPE support
│   ├── dataloader.py                # Memory-aware DataLoader
│   └── training.py                  # train_memgpt function
│
├── scripts/
│   ├── train.py                     # Training script (DDP-compatible)
│   ├── evaluate.py                  # Evaluation on benchmarks
│   └── generate.py                  # Text generation from trained model
│
├── evaluation/
│   ├── __init__.py
│   ├── hellaswag.py                 # HellaSwag data loader
│   └── val_hellaswag.py             # Evaluation logic with loss-based scoring
│
├── logs/
│   ├── log.txt                      # Training logs
│   └── model_*.pt                   # Checkpoints
│
├── .gitignore
├── README.md
└── requirements.txt
```

## Configuration

Edit the config file at `configs/config.json` to adjust model and training hyperparameters:

```json
{
  "model": {
    "block_size": 1024,
    "vocab_size": 50304,
    "n_layer": 12,
    "n_head": 12,
    "n_embd": 768,
    "n_kv_head": 4,
    "max_knn_memories": 81920
  },
  "training": {
    "max_steps": 19073,
    "log_dir": "log",
    "total_batch_size": 2048,
    "B": 64,
    "T": 1024,
    "max_lr": 0.0006,
    "min_lr": 0.00006,
    "warmup_steps": 715,
    "weight_decay": 0.1,
    "learning_rate": 0.0006
  }
}
```
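The training config exposes `max_lr`, `min_lr`, `warmup_steps`, and `max_steps`. A minimal sketch of how these values could combine into a per-step learning rate, assuming a linear warmup followed by cosine decay to `min_lr` (a common choice for GPT-2-style training; the schedule actually used in `scripts/train.py` may differ):

```python
import math


def get_lr(step: int, max_lr: float = 6e-4, min_lr: float = 6e-5,
           warmup_steps: int = 715, max_steps: int = 19073) -> float:
    """Linear warmup, then cosine decay (assumed schedule, not confirmed by this repo)."""
    # 1) Linear warmup: ramps from max_lr / warmup_steps up to max_lr.
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    # 2) Past max_steps, hold at the floor.
    if step > max_steps:
        return min_lr
    # 3) In between, cosine-decay from max_lr down to min_lr.
    decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # goes from 1 to 0
    return min_lr + coeff * (max_lr - min_lr)


for step in (0, 714, 715, 10000, 19073, 20000):
    print(step, f"{get_lr(step):.2e}")
```

Defaults mirror the values in `config.json`; warmup avoids large, unstable updates while optimizer statistics are still uninitialized.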


## Training

### Single-GPU Training

```bash
python scripts/train.py
```

### Distributed Training (Multi-GPU with DDP)

```bash
torchrun --nproc_per_node=NUM_GPUS scripts/train.py
```

Replace `NUM_GPUS` with the number of GPUs available.
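`torchrun` launches one process per GPU and sets the `RANK`, `LOCAL_RANK`, and `WORLD_SIZE` environment variables in each. A minimal sketch of how a training script can detect whether it runs under DDP, assuming `scripts/train.py` follows this common pattern (the `DDPContext` class and `ddp_context` helper are hypothetical names):

```python
import os
from dataclasses import dataclass


@dataclass
class DDPContext:
    enabled: bool
    rank: int
    local_rank: int
    world_size: int

    @property
    def is_master(self) -> bool:
        # Only rank 0 should log, evaluate, and write checkpoints.
        return self.rank == 0


def ddp_context(env=None) -> DDPContext:
    """Read the variables torchrun sets; fall back to single-process defaults."""
    env = os.environ if env is None else env
    if "RANK" in env:
        return DDPContext(True, int(env["RANK"]), int(env["LOCAL_RANK"]),
                          int(env["WORLD_SIZE"]))
    return DDPContext(False, 0, 0, 1)


ctx = ddp_context()
print(ctx, "master:", ctx.is_master)
# Under torchrun, each process would then call
# torch.distributed.init_process_group(backend="nccl")
# and place its model on device f"cuda:{ctx.local_rank}".
```

Falling back to rank 0 of a world of size 1 lets the same script serve both the single-GPU and the `torchrun` commands above.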

## Evaluation

Evaluate on the HellaSwag benchmark:

```bash
python scripts/evaluate.py
```



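Make sure the file `data/hellaswag/hellaswag_val.jsonl` is present.

The evaluation uses completion scoring: the model's loss is computed only over the tokens of each candidate ending (the shared context is masked out), and the ending with the lowest loss is taken as the prediction.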
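A minimal NumPy sketch of loss-based completion scoring, assuming the model has already produced logits for each candidate sequence (context + ending) and a mask that is 1 on ending tokens and 0 on context or padding. The function name `pick_ending` and the array layout are illustrative, not taken from `evaluation/val_hellaswag.py`:

```python
import numpy as np


def pick_ending(logits: np.ndarray, tokens: np.ndarray, mask: np.ndarray) -> int:
    """
    logits: (C, T, V) model outputs for C candidate sequences
    tokens: (C, T) token ids of each candidate sequence
    mask:   (C, T) 1 where the token belongs to the candidate ending
    Returns the index of the ending with the lowest average loss.
    """
    # Position t predicts token t + 1, so shift logits and targets by one.
    shift_logits = logits[:, :-1, :]
    shift_tokens = tokens[:, 1:]
    shift_mask = mask[:, 1:].astype(np.float64)
    # Numerically stable log-softmax over the vocabulary.
    m = shift_logits.max(axis=-1, keepdims=True)
    log_probs = shift_logits - m - np.log(np.exp(shift_logits - m).sum(axis=-1, keepdims=True))
    # Per-token cross-entropy of the actual next token.
    nll = -np.take_along_axis(log_probs, shift_tokens[..., None], axis=-1)[..., 0]
    # Average loss over ending tokens only (context tokens are masked out).
    avg_loss = (nll * shift_mask).sum(axis=1) / shift_mask.sum(axis=1)
    return int(avg_loss.argmin())


rng = np.random.default_rng(0)
C, T, V = 4, 6, 50
tokens = rng.integers(0, V, size=(C, T))
mask = np.zeros((C, T), dtype=int)
mask[:, 3:] = 1  # last 3 tokens are the candidate ending
logits = rng.standard_normal((C, T, V))
print("predicted ending:", pick_ending(logits, tokens, mask))
```

Averaging (rather than summing) the loss keeps longer endings from being penalized simply for having more tokens.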



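## Attention Mechanism Notes

### Grouped Query Attention (GQA)

- `n_head` query heads share `n_kv_head` key/value heads; with the default config (`n_head = 12`, `n_kv_head = 4`), each key/value head serves 12 / 4 = 3 query heads
- Query heads are grouped and averaged before memory lookup
- Using fewer K/V heads than standard multi-head attention reduces K/V projection cost and the size of cached keys/values, which matters most for large models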
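A minimal NumPy sketch of grouped query attention with the default head counts: each K/V head is repeated to serve its group of `n_head // n_kv_head` query heads, followed by standard causal scaled dot-product attention. This is a generic illustration of the technique, not the code in `model_core/attention.py`:

```python
import numpy as np


def grouped_query_attention(q, k, v):
    """
    q:    (n_head, T, d)     query heads
    k, v: (n_kv_head, T, d)  shared key/value heads
    Returns: (n_head, T, d)
    """
    n_head, T, d = q.shape
    n_kv_head = k.shape[0]
    assert n_head % n_kv_head == 0, "n_head must be a multiple of n_kv_head"
    group = n_head // n_kv_head
    # Query head h uses K/V head h // group: repeat each K/V head `group` times.
    k = np.repeat(k, group, axis=0)
    v = np.repeat(v, group, axis=0)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d)       # (n_head, T, T)
    causal = np.triu(np.ones((T, T), dtype=bool), k=1)   # True above the diagonal = future
    scores = np.where(causal, -np.inf, scores)
    scores -= scores.max(axis=-1, keepdims=True)         # stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


rng = np.random.default_rng(0)
T, d = 8, 64  # head_dim = n_embd / n_head = 768 / 12 = 64
q = rng.standard_normal((12, T, d))
k = rng.standard_normal((4, T, d))
v = rng.standard_normal((4, T, d))
out = grouped_query_attention(q, k, v)
print(out.shape)  # (12, 8, 64)
```

Only 4 K/V heads are projected and cached instead of 12, while every query head still gets its own attention pattern.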



### KNN Memory Integration

- The memory buffer holds at most 81,920 tokens (`max_knn_memories`)
- Query vectors are projected and grouped for efficient KNN search
- Careful shape transformations keep grouped matching fast
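A minimal sketch of a bounded kNN memory: past keys/values are appended to a buffer capped at `max_knn_memories`, and the query heads in each GQA group are averaged into one lookup vector before an exact top-k inner-product search. The `KNNMemory` class, the FIFO eviction policy, and the brute-force search are assumptions for illustration; the Memorizing Transformers paper relies on approximate kNN search, and this repository's implementation may differ:

```python
import numpy as np


class KNNMemory:
    """Bounded FIFO store of (key, value) pairs with exact top-k search (illustrative)."""

    def __init__(self, dim: int, max_memories: int = 81920):
        self.max_memories = max_memories
        self.keys = np.empty((0, dim))
        self.values = np.empty((0, dim))

    def add(self, keys: np.ndarray, values: np.ndarray) -> None:
        # Append new entries, then keep only the newest `max_memories` of them.
        self.keys = np.concatenate([self.keys, keys])[-self.max_memories:]
        self.values = np.concatenate([self.values, values])[-self.max_memories:]

    def clear(self) -> None:
        self.keys, self.values = self.keys[:0], self.values[:0]

    def search(self, queries: np.ndarray, k: int):
        """queries: (N, dim) -> (keys, values), each (N, k, dim), top-k by inner product."""
        k = min(k, len(self.keys))
        if k == 0:
            empty = np.empty((len(queries), 0, self.keys.shape[1]))
            return empty, empty.copy()
        scores = queries @ self.keys.T                        # (N, M)
        idx = np.argpartition(-scores, k - 1, axis=1)[:, :k]  # top-k indices, unordered
        return self.keys[idx], self.values[idx]


def group_queries(q: np.ndarray, n_kv_head: int) -> np.ndarray:
    """Average each group of query heads: (n_head, T, d) -> (n_kv_head, T, d)."""
    n_head, T, d = q.shape
    return q.reshape(n_kv_head, n_head // n_kv_head, T, d).mean(axis=1)


rng = np.random.default_rng(0)
mem = KNNMemory(dim=64, max_memories=100)
mem.add(rng.standard_normal((150, 64)), rng.standard_normal((150, 64)))
lookup = group_queries(rng.standard_normal((12, 8, 64)), n_kv_head=4)  # (4, 8, 64)
top_k, top_v = mem.search(lookup.reshape(-1, 64), k=16)
print(mem.keys.shape, top_k.shape)
```

Grouping before the search cuts the number of lookups by a factor of `n_head // n_kv_head`, and the head ordering in `group_queries` matches the `h // group` mapping used for K/V sharing.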



### XL-style Attention + Memory Clearing

- Recurrence with cached memory states from previous segments
- Custom memory clearing prevents stale tokens from influencing attention
- Helps stabilize long training runs
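A minimal sketch of an XL-style recurrent cache with clearing: keys/values from the previous segment are kept and prepended to the current segment's keys/values, and the cache is reset when a new document starts or after a fixed number of segments. In a real PyTorch model the cached tensors would be detached from the autograd graph, as in Transformer-XL. The `XLCache` class, its `lifespan` parameter, and both reset triggers are hypothetical stand-ins for the repository's custom clearing logic:

```python
import numpy as np


class XLCache:
    """Caches the previous segment's K/V and resets it on demand (illustrative)."""

    def __init__(self, lifespan: int = 4):
        self.lifespan = lifespan  # hypothetical: max segments before a forced reset
        self.age = 0
        self.k = None
        self.v = None

    def clear(self) -> None:
        self.k, self.v, self.age = None, None, 0

    def extend(self, k: np.ndarray, v: np.ndarray, new_document: bool = False):
        """
        k, v: (T, d) for the current segment.
        Returns (k_all, v_all) = cached + current, for attention over both.
        """
        if new_document or self.age >= self.lifespan:
            self.clear()  # drop stale tokens so they cannot influence attention
        if self.k is None:
            k_all, v_all = k, v
        else:
            k_all = np.concatenate([self.k, k])
            v_all = np.concatenate([self.v, v])
        # Keep only the current segment for the next step (one-segment recurrence).
        self.k, self.v = k.copy(), v.copy()
        self.age += 1
        return k_all, v_all


rng = np.random.default_rng(0)
cache = XLCache(lifespan=2)
for step in range(4):
    k, v = rng.standard_normal((8, 64)), rng.standard_normal((8, 64))
    k_all, v_all = cache.extend(k, v)
    print(step, k_all.shape)  # (8, 64), (16, 64), (8, 64) after the reset, (16, 64)
```

Bounding how long cached states survive trades some long-range context for protection against attending to tokens from unrelated earlier documents.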



### Positional Encoding

- Rotary Positional Encoding (RoPE) replaces the standard sin/cos encoding
- RoPE encodes relative position directly in the query-key dot product, which can improve generalization over longer contexts
- Implemented in `model_core/attention.py`
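A minimal NumPy sketch of rotary positional encoding: each pair of channels is rotated by a position-dependent angle, so the dot product between a rotated query and a rotated key depends only on their relative offset. The base of 10000 and the interleaved-pair layout are assumptions; the repository may use a different convention (for example, rotating the two halves of the head dimension):

```python
import numpy as np


def rope(x: np.ndarray, positions: np.ndarray, base: float = 10000.0) -> np.ndarray:
    """
    x: (T, d) with even d; positions: (T,) integer positions.
    Rotates each channel pair (2i, 2i+1) by the angle pos * base**(-2i/d).
    """
    T, d = x.shape
    assert d % 2 == 0, "head dimension must be even"
    inv_freq = base ** (-np.arange(0, d, 2) / d)      # (d/2,)
    angles = positions[:, None] * inv_freq[None, :]  # (T, d/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    out[:, 0::2] = x_even * cos - x_odd * sin        # 2-D rotation of each pair
    out[:, 1::2] = x_even * sin + x_odd * cos
    return out


rng = np.random.default_rng(0)
q = rng.standard_normal((1, 64))
k = rng.standard_normal((1, 64))
s_near = rope(q, np.array([5])) @ rope(k, np.array([2])).T
s_far = rope(q, np.array([103])) @ rope(k, np.array([100])).T
print(np.allclose(s_near, s_far))  # True: both pairs are 3 positions apart
```

Because rotations preserve vector norms, RoPE changes only the angles between queries and keys, not their magnitudes.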



## Dataloader & Dataset Handling

- Sharded training data stored as `.npy` token files
- Matching stride and memory alignment logic
- Optimized for DDP compatibility and large-scale throughput
- Code in `model_core/dataloader.py` and `data/fineweb.py`
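A minimal sketch of a sharded, DDP-aware loader in the style of nanoGPT's `DataLoaderLite` (an assumption about the design; `model_core/dataloader.py` may differ, and its memory-alignment logic is not modeled here). Each rank starts at offset `rank * B * T` within the current shard and advances by `B * T * world_size`, so ranks read disjoint token windows; targets are the inputs shifted by one token. The hypothetical `ShardedLoader` takes in-memory arrays so the example runs without files; a real loader would read each shard with `np.load`:

```python
import numpy as np


class ShardedLoader:
    def __init__(self, shards, B: int, T: int, rank: int = 0, world_size: int = 1):
        self.shards = shards  # list of 1-D token arrays (stand-ins for .npy shards)
        self.B, self.T = B, T
        self.rank, self.world_size = rank, world_size
        self.reset()

    def reset(self) -> None:
        self.shard_idx = 0
        self.tokens = self.shards[0]
        self.pos = self.B * self.T * self.rank

    def next_batch(self):
        B, T = self.B, self.T
        buf = self.tokens[self.pos : self.pos + B * T + 1]
        x = buf[:-1].reshape(B, T)  # inputs
        y = buf[1:].reshape(B, T)   # next-token targets
        self.pos += B * T * self.world_size
        # Move to the next shard if this rank's next window would run past the end.
        if self.pos + B * T + 1 > len(self.tokens):
            self.shard_idx = (self.shard_idx + 1) % len(self.shards)
            self.tokens = self.shards[self.shard_idx]
            self.pos = B * T * self.rank
        return x, y


shards = [np.arange(0, 100), np.arange(1000, 1100)]
r0 = ShardedLoader(shards, B=2, T=4, rank=0, world_size=2)
r1 = ShardedLoader(shards, B=2, T=4, rank=1, world_size=2)
x0, y0 = r0.next_batch()
x1, y1 = r1.next_batch()
print(x0.ravel(), x1.ravel())  # rank 0 reads tokens 0-7, rank 1 reads tokens 8-15
```

Interleaving ranks by a fixed stride needs no communication between processes: every rank can compute its own windows from `rank` and `world_size` alone.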



## Requirements

Install dependencies:

```bash

pip install -r requirements.txt

```

Ensure PyTorch and CUDA versions match your GPU setup.