Gurjot Singh committed on
Commit e06a7a0
Parent(s): 03739b9

Add faiss integration for storage
README.md CHANGED
@@ -465,7 +465,36 @@ For production level scenarios you will most likely want to leverage an enterpri
  >
  > You can Compile the AGE from source code and fix it.
 
+ ### Using Faiss for Storage
+ - Install the required dependencies:
+ ```
+ pip install faiss-cpu
+ ```
+ You can also install `faiss-gpu` if you have GPU support.
 
+ - Here we use `sentence-transformers`, but you can also use an `OpenAIEmbedding` model with `3072` dimensions.
+
+ ```
+ async def embedding_func(texts: list[str]) -> np.ndarray:
+     model = SentenceTransformer('all-MiniLM-L6-v2')
+     embeddings = model.encode(texts, convert_to_numpy=True)
+     return embeddings
+
+ # Initialize LightRAG with the LLM model function and embedding function
+ rag = LightRAG(
+     working_dir=WORKING_DIR,
+     llm_model_func=llm_model_func,
+     embedding_func=EmbeddingFunc(
+         embedding_dim=384,
+         max_token_size=8192,
+         func=embedding_func,
+     ),
+     vector_storage="FaissVectorDBStorage",
+     vector_db_storage_cls_kwargs={
+         "cosine_better_than_threshold": 0.3  # Your desired threshold
+     },
+ )
+ ```
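Once initialized this way, inserting documents and querying work the same as with the default vector storage. A minimal sketch, reusing the `rag` instance from the snippet above (and `QueryParam` from `lightrag`):

```
rag.insert(["Your first document ...", "Your second document ..."])
print(rag.query("What are the main themes?", param=QueryParam(mode="hybrid")))
```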
 
  ### Insert Custom KG
 
examples/test_faiss.py ADDED
@@ -0,0 +1,104 @@
+ import os
+ import logging
+ import numpy as np
+
+ from dotenv import load_dotenv
+ from sentence_transformers import SentenceTransformer
+
+ from openai import AzureOpenAI
+ from lightrag import LightRAG, QueryParam
+ from lightrag.utils import EmbeddingFunc
+ from lightrag.kg.faiss_impl import FaissVectorDBStorage
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+
+ # Load environment variables from .env file
+ load_dotenv()
+ AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION")
+ AZURE_OPENAI_DEPLOYMENT = os.getenv("AZURE_OPENAI_DEPLOYMENT")
+ AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
+ AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
+
+
+ async def llm_model_func(
+     prompt,
+     system_prompt=None,
+     history_messages=None,
+     keyword_extraction=False,
+     **kwargs,
+ ) -> str:
+     # Create a client for Azure OpenAI
+     client = AzureOpenAI(
+         api_key=AZURE_OPENAI_API_KEY,
+         api_version=AZURE_OPENAI_API_VERSION,
+         azure_endpoint=AZURE_OPENAI_ENDPOINT,
+     )
+
+     # Build the messages list for the conversation
+     messages = []
+     if system_prompt:
+         messages.append({"role": "system", "content": system_prompt})
+     if history_messages:
+         messages.extend(history_messages)
+     messages.append({"role": "user", "content": prompt})
+
+     # Call the LLM
+     chat_completion = client.chat.completions.create(
+         model=AZURE_OPENAI_DEPLOYMENT,
+         messages=messages,
+         temperature=kwargs.get("temperature", 0),
+         top_p=kwargs.get("top_p", 1),
+         n=kwargs.get("n", 1),
+     )
+
+     return chat_completion.choices[0].message.content
+
+
+ async def embedding_func(texts: list[str]) -> np.ndarray:
+     model = SentenceTransformer('all-MiniLM-L6-v2')
+     embeddings = model.encode(texts, convert_to_numpy=True)
+     return embeddings
+
+
+ def main():
+     WORKING_DIR = "./dickens"
+
+     # Initialize LightRAG with the LLM model function and embedding function
+     rag = LightRAG(
+         working_dir=WORKING_DIR,
+         llm_model_func=llm_model_func,
+         embedding_func=EmbeddingFunc(
+             embedding_dim=384,
+             max_token_size=8192,
+             func=embedding_func,
+         ),
+         vector_storage="FaissVectorDBStorage",
+         vector_db_storage_cls_kwargs={
+             "cosine_better_than_threshold": 0.3  # Your desired threshold
+         },
+     )
+
+     # Insert the two books into LightRAG
+     with open("./book_1.txt", encoding="utf-8") as book1, open(
+         "./book_2.txt", encoding="utf-8"
+     ) as book2:
+         rag.insert([book1.read(), book2.read()])
+
+     query_text = "What are the main themes?"
+
+     print("Result (Naive):")
+     print(rag.query(query_text, param=QueryParam(mode="naive")))
+
+     print("\nResult (Local):")
+     print(rag.query(query_text, param=QueryParam(mode="local")))
+
+     print("\nResult (Global):")
+     print(rag.query(query_text, param=QueryParam(mode="global")))
+
+     print("\nResult (Hybrid):")
+     print(rag.query(query_text, param=QueryParam(mode="hybrid")))
+
+
+ if __name__ == "__main__":
+     main()
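The script reads its Azure OpenAI settings from a `.env` file next to it. A minimal sketch with placeholder values (substitute your own deployment details; the version string is only an example):

```
AZURE_OPENAI_API_VERSION=2024-02-15-preview
AZURE_OPENAI_DEPLOYMENT=your-deployment-name
AZURE_OPENAI_API_KEY=your-api-key
AZURE_OPENAI_ENDPOINT=https://your-resource.openai.azure.com/
```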
lightrag/kg/faiss_impl.py ADDED
@@ -0,0 +1,318 @@
+ import os
+ import time
+ import asyncio
+ import faiss
+ import json
+ import numpy as np
+ from tqdm.asyncio import tqdm as tqdm_async
+ from dataclasses import dataclass
+
+ from lightrag.utils import (
+     logger,
+     compute_mdhash_id,
+ )
+ from lightrag.base import (
+     BaseVectorStorage,
+ )
+
+
+ @dataclass
+ class FaissVectorDBStorage(BaseVectorStorage):
+     """
+     A Faiss-based Vector DB Storage for LightRAG.
+     Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search.
+     """
+
+     cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
+
+     def __post_init__(self):
+         # Grab config values if available
+         config = self.global_config.get("vector_db_storage_cls_kwargs", {})
+         self.cosine_better_than_threshold = config.get(
+             "cosine_better_than_threshold", self.cosine_better_than_threshold
+         )
+
+         # Where to save the index file for persistent storage
+         self._faiss_index_file = os.path.join(
+             self.global_config["working_dir"], f"faiss_index_{self.namespace}.index"
+         )
+         self._meta_file = self._faiss_index_file + ".meta.json"
+
+         self._max_batch_size = self.global_config["embedding_batch_num"]
+         # Embedding dimension (e.g. 768) must match your embedding function
+         self._dim = self.embedding_func.embedding_dim
+
+         # Create an empty Faiss index for inner product
+         # (equivalent to cosine similarity on normalized vectors).
+         # If you have a large number of vectors, you might want IVF or another index type;
+         # for demonstration, we use a simple IndexFlatIP.
+         self._index = faiss.IndexFlatIP(self._dim)
+
+         # Keep a local store for metadata, IDs, etc.
+         # Maps <int faiss_id> -> metadata (including your original ID).
+         self._id_to_meta = {}
+
+         # Attempt to load an existing index + metadata from disk
+         self._load_faiss_index()
+
+     async def upsert(self, data: dict[str, dict]):
+         """
+         Insert or update vectors in the Faiss index.
+
+         data: {
+             "custom_id_1": {
+                 "content": <text>,
+                 ...metadata...
+             },
+             "custom_id_2": {
+                 "content": <text>,
+                 ...metadata...
+             },
+             ...
+         }
+         """
+         logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
+         if not data:
+             logger.warning("You are inserting empty data to the vector DB")
+             return []
+
+         current_time = time.time()
+
+         # Prepare data for embedding
+         list_data = []
+         contents = []
+         for k, v in data.items():
+             # Store only known meta fields if needed
+             meta = {mf: v[mf] for mf in self.meta_fields if mf in v}
+             meta["__id__"] = k
+             meta["__created_at__"] = current_time
+             list_data.append(meta)
+             contents.append(v["content"])
+
+         # Split into batches for embedding if needed
+         batches = [
+             contents[i : i + self._max_batch_size]
+             for i in range(0, len(contents), self._max_batch_size)
+         ]
+
+         pbar = tqdm_async(
+             total=len(batches), desc="Generating embeddings", unit="batch"
+         )
+
+         async def wrapped_task(batch):
+             result = await self.embedding_func(batch)
+             pbar.update(1)
+             return result
+
+         embedding_tasks = [wrapped_task(batch) for batch in batches]
+         embeddings_list = await asyncio.gather(*embedding_tasks)
+         pbar.close()
+
+         # Flatten the list of arrays
+         embeddings = np.concatenate(embeddings_list, axis=0)
+         if len(embeddings) != len(list_data):
+             logger.error(
+                 f"Embedding size mismatch. Embeddings: {len(embeddings)}, Data: {len(list_data)}"
+             )
+             return []
+
+         # Normalize embeddings for cosine similarity (in-place)
+         faiss.normalize_L2(embeddings)
+
+         # Upsert logic:
+         # 1. Identify which vectors to remove if they already exist
+         # 2. Remove them
+         # 3. Add the new vectors
+         existing_ids_to_remove = []
+         for meta in list_data:
+             faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"])
+             if faiss_internal_id is not None:
+                 existing_ids_to_remove.append(faiss_internal_id)
+
+         if existing_ids_to_remove:
+             self._remove_faiss_ids(existing_ids_to_remove)
+
+         # Step 2: Add new vectors
+         start_idx = self._index.ntotal
+         self._index.add(embeddings)
+
+         # Step 3: Store metadata + vector for each new ID
+         for i, meta in enumerate(list_data):
+             fid = start_idx + i
+             # Store the raw vector so we can rebuild the index after removals
+             meta["__vector__"] = embeddings[i].tolist()
+             self._id_to_meta[fid] = meta
+
+         logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
+         return [m["__id__"] for m in list_data]
+
+     async def query(self, query: str, top_k=5):
+         """
+         Search by a textual query; returns top_k results with their metadata + similarity score.
+         """
+         embedding = await self.embedding_func([query])
+         # embedding has shape (1, dim)
+         embedding = np.array(embedding, dtype=np.float32)
+         faiss.normalize_L2(embedding)  # in-place normalization
+
+         logger.info(
+             f"Query: {query}, top_k: {top_k}, threshold: {self.cosine_better_than_threshold}"
+         )
+
+         # Perform the similarity search
+         distances, indices = self._index.search(embedding, top_k)
+
+         distances = distances[0]
+         indices = indices[0]
+
+         results = []
+         for dist, idx in zip(distances, indices):
+             if idx == -1:
+                 # Faiss returns -1 if there is no neighbor
+                 continue
+
+             # Apply the cosine similarity threshold
+             if dist < self.cosine_better_than_threshold:
+                 continue
+
+             meta = self._id_to_meta.get(idx, {})
+             results.append(
+                 {
+                     **meta,
+                     "id": meta.get("__id__"),
+                     "distance": float(dist),
+                     "created_at": meta.get("__created_at__"),
+                 }
+             )
+
+         return results
+
+     @property
+     def client_storage(self):
+         # Return whatever structure LightRAG might need for debugging
+         return {"data": list(self._id_to_meta.values())}
+
+     async def delete(self, ids: list[str]):
+         """
+         Delete vectors for the provided custom IDs.
+         """
+         logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
+         to_remove = []
+         for cid in ids:
+             fid = self._find_faiss_id_by_custom_id(cid)
+             if fid is not None:
+                 to_remove.append(fid)
+
+         if to_remove:
+             self._remove_faiss_ids(to_remove)
+         logger.info(
+             f"Successfully deleted {len(to_remove)} vectors from {self.namespace}"
+         )
+
+     async def delete_entity(self, entity_name: str):
+         """
+         Delete a single entity by computing its hashed ID
+         the same way the rest of the code does with `compute_mdhash_id`.
+         """
+         entity_id = compute_mdhash_id(entity_name, prefix="ent-")
+         logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
+         await self.delete([entity_id])
+
+     async def delete_entity_relation(self, entity_name: str):
+         """
+         Delete relations for a given entity by scanning metadata.
+         """
+         logger.debug(f"Searching relations for entity {entity_name}")
+         relations = []
+         for fid, meta in self._id_to_meta.items():
+             if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name:
+                 relations.append(fid)
+
+         logger.debug(f"Found {len(relations)} relations for {entity_name}")
+         if relations:
+             self._remove_faiss_ids(relations)
+             logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
+
+     async def index_done_callback(self):
+         """
+         Called after indexing is done (saves the Faiss index + metadata).
+         """
+         self._save_faiss_index()
+         logger.info("Faiss index saved successfully.")
+
+     # --------------------------------------------------------------------------------
+     # Internal helper methods
+     # --------------------------------------------------------------------------------
+
+     def _find_faiss_id_by_custom_id(self, custom_id: str):
+         """
+         Return the Faiss internal ID for a given custom ID, or None if not found.
+         """
+         for fid, meta in self._id_to_meta.items():
+             if meta.get("__id__") == custom_id:
+                 return fid
+         return None
+
+     def _remove_faiss_ids(self, fid_list):
+         """
+         Remove a list of internal Faiss IDs from the index.
+         Because IndexFlatIP doesn't support removals,
+         we rebuild the index excluding those vectors.
+         """
+         keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list]
+
+         # Rebuild the index from the vectors we keep
+         vectors_to_keep = []
+         new_id_to_meta = {}
+         for new_fid, old_fid in enumerate(keep_fids):
+             vec_meta = self._id_to_meta[old_fid]
+             vectors_to_keep.append(vec_meta["__vector__"])  # stored as a list
+             new_id_to_meta[new_fid] = vec_meta
+
+         # Re-initialize the index
+         self._index = faiss.IndexFlatIP(self._dim)
+         if vectors_to_keep:
+             arr = np.array(vectors_to_keep, dtype=np.float32)
+             self._index.add(arr)
+
+         self._id_to_meta = new_id_to_meta
+
+     def _save_faiss_index(self):
+         """
+         Save the current Faiss index + metadata to disk so it can persist across runs.
+         """
+         faiss.write_index(self._index, self._faiss_index_file)
+
+         # Save the metadata dict as JSON. _id_to_meta maps
+         # { int: { '__id__': doc_id, '__vector__': [float, ...], ... } },
+         # but JSON requires string keys, so convert them.
+         serializable_dict = {}
+         for fid, meta in self._id_to_meta.items():
+             serializable_dict[str(fid)] = meta
+
+         with open(self._meta_file, "w", encoding="utf-8") as f:
+             json.dump(serializable_dict, f)
+
+     def _load_faiss_index(self):
+         """
+         Load the Faiss index + metadata from disk if they exist,
+         and rebuild the in-memory structures so we can query.
+         """
+         if not os.path.exists(self._faiss_index_file):
+             logger.warning("No existing Faiss index file found. Starting fresh.")
+             return
+
+         try:
+             # Load the Faiss index
+             self._index = faiss.read_index(self._faiss_index_file)
+             # Load metadata
+             with open(self._meta_file, "r", encoding="utf-8") as f:
+                 stored_dict = json.load(f)
+
+             # Convert string keys back to int
+             self._id_to_meta = {}
+             for fid_str, meta in stored_dict.items():
+                 self._id_to_meta[int(fid_str)] = meta
+
+             logger.info(
+                 f"Faiss index loaded with {self._index.ntotal} vectors from {self._faiss_index_file}"
+             )
+         except Exception as e:
+             logger.error(f"Failed to load Faiss index or metadata: {e}")
+             logger.warning("Starting with an empty Faiss index.")
+             self._index = faiss.IndexFlatIP(self._dim)
+             self._id_to_meta = {}
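As the class docstring notes, cosine similarity is obtained by L2-normalizing vectors and searching an inner-product index. A minimal standalone sketch of that identity, using toy vectors that are not part of this commit:

```
import faiss
import numpy as np

# Two toy 3-d vectors; after L2 normalization, inner product == cosine similarity.
vecs = np.array([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]], dtype=np.float32)
faiss.normalize_L2(vecs)  # in-place, as in upsert()

index = faiss.IndexFlatIP(3)  # inner-product index, as in __post_init__
index.add(vecs)

query = np.array([[1.0, 2.0, 3.0]], dtype=np.float32)
faiss.normalize_L2(query)
scores, ids = index.search(query, 2)
print(scores)  # first hit ~1.0 (same direction); hits below
               # cosine_better_than_threshold would be filtered out in query()
```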
lightrag/lightrag.py CHANGED
@@ -60,6 +60,7 @@ STORAGES = {
  "PGGraphStorage": ".kg.postgres_impl",
  "GremlinStorage": ".kg.gremlin_impl",
  "PGDocStatusStorage": ".kg.postgres_impl",
+ "FaissVectorDBStorage": ".kg.faiss_impl",
  }
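For context, the STORAGES mapping associates each storage name with the module that defines it, which is how the `vector_storage="FaissVectorDBStorage"` string is resolved. A hypothetical sketch of how such a name-to-class lookup could work (LightRAG's actual loader may differ; `load_storage_class` is an illustrative name, not part of the commit):

```
import importlib

def load_storage_class(name: str):
    # Hypothetical resolver: look up the module path registered for this
    # storage name, e.g. "FaissVectorDBStorage" -> ".kg.faiss_impl"
    module_path = STORAGES[name]
    # Relative import anchored at the lightrag package
    module = importlib.import_module(module_path, package="lightrag")
    # By convention, the class is named after the mapping key
    return getattr(module, name)
```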