abhinavv3 committed
Commit 174b26d · 1 Parent(s): 2199e22

Added KNN memory. Added search, retrieve, and add functionality to memory.

Files changed (2)
  1. .vscode/settings.json +6 -0
  2. model_core/attention.py +169 -24
.vscode/settings.json ADDED
@@ -0,0 +1,6 @@
+ {
+     "editor.quickSuggestions": {
+         "comments": "on",
+         "strings": "on"
+     }
+ }
model_core/attention.py CHANGED
@@ -1,28 +1,173 @@
+ import torch
  import torch.nn as nn
- from torch.nn import functional as F
-
-
- class CasualSelfAttention(nn.Module):
-
-     def __init__(self, config):
-         super().__init__()
-         assert config.n_embd % config.n_head == 0
-         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
-         self.c_proj = nn.Linear(config.n_embd, config.n_embd)
-         self.c_proj.NANOGPT_SCALE_INIT = 1
-         self.n_head = config.n_head
-         self.n_embd = config.n_embd
-
-     def forward(self, x):
-         B, T, C = x.size()
-         qkv = self.c_attn(x)
-         q, k, v = qkv.split(self.n_embd, dim=2)
-         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1,2) # (B, nh, T, hs)
-         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1,2) # (B, nh, T, hs)
-         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1,2) # (B, nh, T, hs)
-
-         y = F.scaled_dot_product_attention(q, k, v, is_causal=True) # flash attention
-
-         y = y.transpose(1,2).contiguous().view(B, T, C) # (B, T, C)
-         y = self.c_proj(y)
-         return y
+ import torch.nn.functional as F
+ import numpy as np
+ import math
+ import faiss
+ from einops import rearrange, einsum
+ from dataclasses import dataclass
+ import inspect
+ import os
+
+ class RotaryPositionalEncoding(nn.Module):
+     def __init__(self, dim, max_seq_len=1024, base=10000):
+         super().__init__()
+         assert dim % 2 == 0
+
+         self.dim = dim
+         self.max_seq_len = max_seq_len
+         self.base = base
+
+         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) # [dim//2]
+         self.register_buffer('inv_freq', inv_freq)
+
+         self._cached_freqs = None
+         self._cached_seq_len = 0
+
+     def _get_freqs(self, seq_len, device):
+         if self._cached_freqs is None or seq_len > self._cached_seq_len:
+             t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype) # [seq_len]
+             freqs = torch.outer(t, self.inv_freq) # [seq_len, dim//2]
+             cos = freqs.cos() # [seq_len, dim//2]
+             sin = freqs.sin()
+             self._cached_freqs = (cos, sin)
+             self._cached_seq_len = seq_len
+         return self._cached_freqs[0][:seq_len], self._cached_freqs[1][:seq_len]
+
+     def apply_rotary_pos_emb(self, q, k, seq_len):
+         assert q.shape[-1] == self.dim, f"Expected q.shape[-1] == {self.dim}, got {q.shape[-1]}"
+         assert k.shape[-1] == self.dim, f"Expected k.shape[-1] == {self.dim}, got {k.shape[-1]}"
+
+         device = q.device
+         cos, sin = self._get_freqs(seq_len, device) # both [seq_len, dim//2]
+
+         # Expand to match q/k: [1, 1, seq_len, dim//2]
+         cos = cos[None, None, :, :].expand(q.shape[0], q.shape[1], -1, -1)
+         sin = sin[None, None, :, :].expand(q.shape[0], q.shape[1], -1, -1)
+
+         def apply(x):
+             x1 = x[..., ::2]
+             x2 = x[..., 1::2]
+             x_rotated_even = x1 * cos - x2 * sin
+             x_rotated_odd = x1 * sin + x2 * cos
+             return torch.stack((x_rotated_even, x_rotated_odd), dim=-1).flatten(-2)
+
+         q_rot = apply(q)
+         k_rot = apply(k)
+         return q_rot, k_rot
+
+ class KNN():
+     def __init__(self, dim, max_memories, process_rank=0):
+         self.dim = dim
+         self.max_memories = max_memories
+         self.shape = (max_memories, 2, dim)
+         self.db_offset = 0
+         self.db_filepath = f"./memory_rank_{process_rank}.memmap"
+         self.db = np.memmap(self.db_filepath, mode='w+', dtype=np.float32, shape=self.shape)
+         self.index = faiss.IndexFlatL2(dim)
+         self.process_rank = process_rank
+
+     def add_to_db(self, new_data):
+         new_data_len = new_data.shape[0] # B*T
+         ids = (np.arange(new_data_len) + self.db_offset) % self.max_memories
+         self.db[ids] = new_data.detach().cpu().numpy()
+         self.db_offset = (self.db_offset + new_data_len) % self.max_memories
+         self.db.flush()
+
+     def search_and_retrieve(self, query_vecs, topk):
+         distances, indices = self.index.search(query_vecs, topk)
+         kvs = self.db[indices]
+         return kvs
+
+     def add(self, new_data):
+         new_data = new_data.flatten(0, 1) # (B,T,2,C) --> (B*T,2,C)
+         self.add_to_db(new_data)
+         keys, vals = new_data.unbind(dim=-2) # (B,T,C)
+         keys = keys.detach().cpu().numpy()
+         keys = np.ascontiguousarray(keys)
+         self.index.add(keys)
+
+     def search(self, query_vecs, topk):
+         query_batch_size, query_seq_len = query_vecs.shape[0], query_vecs.shape[1]
+         query_vecs = query_vecs.flatten(0, 1) # (B,T,C) --> (B*T,C)
+         kvs = self.search_and_retrieve(np.ascontiguousarray(query_vecs.detach().cpu().numpy()), topk)
+         kvs = torch.tensor(kvs) # (B*T,TOPK,2,C)
+         kvs = torch.unflatten(kvs, 0, (query_batch_size, query_seq_len)) # (B*T,TOPK,2,C) --> (B,T,TOPK,2,C)
+         return kvs
+
+     def clear(self):
+         self.index.reset()
+         self.db[:] = 0
+         self.db_offset = 0
+
+     def cleanup(self):
+         # call it after all training completed
+         try:
+             if os.path.exists(self.db_filepath):
+                 os.remove(self.db_filepath)
+         except:
+             pass
+
+
+ # import torch.nn as nn
+ # from torch.nn import functional as F
+
+
+ # class CasualSelfAttention(nn.Module):
+
+ #     def __init__(self, config):
+ #         super().__init__()
+ #         assert config.n_embd % config.n_head == 0
+ #         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
+ #         self.c_proj = nn.Linear(config.n_embd, config.n_embd)
+ #         self.c_proj.NANOGPT_SCALE_INIT = 1
+ #         self.n_head = config.n_head
+ #         self.n_embd = config.n_embd
+
+ #     def forward(self, x):
+ #         B, T, C = x.size()
+ #         qkv = self.c_attn(x)
+ #         q, k, v = qkv.split(self.n_embd, dim=2)
+ #         k = k.view(B, T, self.n_head, C // self.n_head).transpose(1,2) # (B, nh, T, hs)
+ #         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1,2) # (B, nh, T, hs)
+ #         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1,2) # (B, nh, T, hs)
+
+ #         y = F.scaled_dot_product_attention(q, k, v, is_causal=True) # flash attention
+
+ #         y = y.transpose(1,2).contiguous().view(B, T, C) # (B, T, C)
+ #         y = self.c_proj(y)
+ #         return y
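
For context, the following is a minimal usage sketch of the two classes this commit adds; it is not part of the diff. The shapes, hyperparameters, and the `model_core.attention` import path are illustrative assumptions, and it presumes numpy and faiss are installed.

import torch
from model_core.attention import RotaryPositionalEncoding, KNN   # assumed import path

B, H, T, hs = 2, 4, 16, 64            # batch, heads, sequence length, head size (assumed)

# Rotary embeddings are applied to per-head q/k of shape (B, H, T, hs)
rope = RotaryPositionalEncoding(dim=hs, max_seq_len=1024)
q = torch.randn(B, H, T, hs)
k = torch.randn(B, H, T, hs)
q_rot, k_rot = rope.apply_rotary_pos_emb(q, k, seq_len=T)

# The KNN memory stores stacked (key, value) pairs of width C and retrieves top-k neighbours
C = 64
memory = KNN(dim=C, max_memories=4096, process_rank=0)

kv = torch.randn(B, T, 2, C)          # (key, value) pair per position
memory.add(kv)                        # flattens to (B*T, 2, C), writes the memmap, indexes the keys

queries = torch.randn(B, T, C)
retrieved = memory.search(queries, topk=3)   # (B, T, topk, 2, C)
print(retrieved.shape)                # torch.Size([2, 16, 3, 2, 64])

memory.clear()                        # reset the FAISS index and zero the memmap
memory.cleanup()                      # delete the on-disk memmap once training is done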