Gurjot Singh
commited on
Commit
·
e06a7a0
1
Parent(s):
03739b9
Add faiss integration for storage
Browse files- README.md +29 -0
- examples/test_faiss.py +104 -0
- lightrag/kg/faiss_impl.py +318 -0
- lightrag/lightrag.py +1 -0
README.md
CHANGED
@@ -465,7 +465,36 @@ For production level scenarios you will most likely want to leverage an enterpri
|
|
465 |
>
|
466 |
> You can Compile the AGE from source code and fix it.
|
467 |
|
|
|
|
|
|
|
|
|
|
|
|
|
468 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
469 |
|
470 |
### Insert Custom KG
|
471 |
|
|
|
465 |
>
|
466 |
> You can Compile the AGE from source code and fix it.
|
467 |
|
468 |
+
### Using Faiss for Storage
|
469 |
+
- Install the required dependencies:
|
470 |
+
```
|
471 |
+
pip install faiss-cpu
|
472 |
+
```
|
473 |
+
You can also install `faiss-gpu` if you have GPU support.
|
474 |
|
475 |
+
- Here we are using `sentence-transformers` but you can also use `OpenAIEmbedding` model with `3072` dimensions.
|
476 |
+
|
477 |
+
```
|
478 |
+
async def embedding_func(texts: list[str]) -> np.ndarray:
|
479 |
+
model = SentenceTransformer('all-MiniLM-L6-v2')
|
480 |
+
embeddings = model.encode(texts, convert_to_numpy=True)
|
481 |
+
return embeddings
|
482 |
+
|
483 |
+
# Initialize LightRAG with the LLM model function and embedding function
|
484 |
+
rag = LightRAG(
|
485 |
+
working_dir=WORKING_DIR,
|
486 |
+
llm_model_func=llm_model_func,
|
487 |
+
embedding_func=EmbeddingFunc(
|
488 |
+
embedding_dim=384,
|
489 |
+
max_token_size=8192,
|
490 |
+
func=embedding_func,
|
491 |
+
),
|
492 |
+
vector_storage="FaissVectorDBStorage",
|
493 |
+
vector_db_storage_cls_kwargs={
|
494 |
+
"cosine_better_than_threshold": 0.3 # Your desired threshold
|
495 |
+
}
|
496 |
+
)
|
497 |
+
```
|
498 |
|
499 |
### Insert Custom KG
|
500 |
|
examples/test_faiss.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import logging
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
from dotenv import load_dotenv
|
6 |
+
from sentence_transformers import SentenceTransformer
|
7 |
+
|
8 |
+
from openai import AzureOpenAI
|
9 |
+
from lightrag import LightRAG, QueryParam
|
10 |
+
from lightrag.utils import EmbeddingFunc
|
11 |
+
from lightrag.kg.faiss_impl import FaissVectorDBStorage
|
12 |
+
|
13 |
+
# Configure Logging
|
14 |
+
logging.basicConfig(level=logging.INFO)
|
15 |
+
|
16 |
+
# Load environment variables from .env file
|
17 |
+
load_dotenv()
|
18 |
+
AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION")
|
19 |
+
AZURE_OPENAI_DEPLOYMENT = os.getenv("AZURE_OPENAI_DEPLOYMENT")
|
20 |
+
AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
|
21 |
+
AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
|
22 |
+
|
23 |
+
async def llm_model_func(
|
24 |
+
prompt,
|
25 |
+
system_prompt=None,
|
26 |
+
history_messages=[],
|
27 |
+
keyword_extraction=False,
|
28 |
+
**kwargs
|
29 |
+
) -> str:
|
30 |
+
|
31 |
+
# Create a client for AzureOpenAI
|
32 |
+
client = AzureOpenAI(
|
33 |
+
api_key=AZURE_OPENAI_API_KEY,
|
34 |
+
api_version=AZURE_OPENAI_API_VERSION,
|
35 |
+
azure_endpoint=AZURE_OPENAI_ENDPOINT,
|
36 |
+
)
|
37 |
+
|
38 |
+
# Build the messages list for the conversation
|
39 |
+
messages = []
|
40 |
+
if system_prompt:
|
41 |
+
messages.append({"role": "system", "content": system_prompt})
|
42 |
+
if history_messages:
|
43 |
+
messages.extend(history_messages)
|
44 |
+
messages.append({"role": "user", "content": prompt})
|
45 |
+
|
46 |
+
# Call the LLM
|
47 |
+
chat_completion = client.chat.completions.create(
|
48 |
+
model=AZURE_OPENAI_DEPLOYMENT,
|
49 |
+
messages=messages,
|
50 |
+
temperature=kwargs.get("temperature", 0),
|
51 |
+
top_p=kwargs.get("top_p", 1),
|
52 |
+
n=kwargs.get("n", 1),
|
53 |
+
)
|
54 |
+
|
55 |
+
return chat_completion.choices[0].message.content
|
56 |
+
|
57 |
+
|
58 |
+
async def embedding_func(texts: list[str]) -> np.ndarray:
|
59 |
+
model = SentenceTransformer('all-MiniLM-L6-v2')
|
60 |
+
embeddings = model.encode(texts, convert_to_numpy=True)
|
61 |
+
return embeddings
|
62 |
+
|
63 |
+
def main():
|
64 |
+
|
65 |
+
WORKING_DIR = "./dickens"
|
66 |
+
|
67 |
+
# Initialize LightRAG with the LLM model function and embedding function
|
68 |
+
rag = LightRAG(
|
69 |
+
working_dir=WORKING_DIR,
|
70 |
+
llm_model_func=llm_model_func,
|
71 |
+
embedding_func=EmbeddingFunc(
|
72 |
+
embedding_dim=384,
|
73 |
+
max_token_size=8192,
|
74 |
+
func=embedding_func,
|
75 |
+
),
|
76 |
+
vector_storage="FaissVectorDBStorage",
|
77 |
+
vector_db_storage_cls_kwargs={
|
78 |
+
"cosine_better_than_threshold": 0.3 # Your desired threshold
|
79 |
+
}
|
80 |
+
)
|
81 |
+
|
82 |
+
# Insert the custom chunks into LightRAG
|
83 |
+
book1 = open("./book_1.txt", encoding="utf-8")
|
84 |
+
book2 = open("./book_2.txt", encoding="utf-8")
|
85 |
+
|
86 |
+
rag.insert([book1.read(), book2.read()])
|
87 |
+
|
88 |
+
query_text = "What are the main themes?"
|
89 |
+
|
90 |
+
print("Result (Naive):")
|
91 |
+
print(rag.query(query_text, param=QueryParam(mode="naive")))
|
92 |
+
|
93 |
+
print("\nResult (Local):")
|
94 |
+
print(rag.query(query_text, param=QueryParam(mode="local")))
|
95 |
+
|
96 |
+
print("\nResult (Global):")
|
97 |
+
print(rag.query(query_text, param=QueryParam(mode="global")))
|
98 |
+
|
99 |
+
print("\nResult (Hybrid):")
|
100 |
+
print(rag.query(query_text, param=QueryParam(mode="hybrid")))
|
101 |
+
|
102 |
+
|
103 |
+
if __name__ == "__main__":
|
104 |
+
main()
|
lightrag/kg/faiss_impl.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import asyncio
|
4 |
+
import faiss
|
5 |
+
import json
|
6 |
+
import numpy as np
|
7 |
+
from tqdm.asyncio import tqdm as tqdm_async
|
8 |
+
from dataclasses import dataclass
|
9 |
+
|
10 |
+
from lightrag.utils import (
|
11 |
+
logger,
|
12 |
+
compute_mdhash_id,
|
13 |
+
)
|
14 |
+
from lightrag.base import (
|
15 |
+
BaseVectorStorage,
|
16 |
+
)
|
17 |
+
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class FaissVectorDBStorage(BaseVectorStorage):
|
21 |
+
"""
|
22 |
+
A Faiss-based Vector DB Storage for LightRAG.
|
23 |
+
Uses cosine similarity by storing normalized vectors in a Faiss index with inner product search.
|
24 |
+
"""
|
25 |
+
cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
|
26 |
+
|
27 |
+
def __post_init__(self):
|
28 |
+
# Grab config values if available
|
29 |
+
config = self.global_config.get("vector_db_storage_cls_kwargs", {})
|
30 |
+
self.cosine_better_than_threshold = config.get(
|
31 |
+
"cosine_better_than_threshold", self.cosine_better_than_threshold
|
32 |
+
)
|
33 |
+
|
34 |
+
# Where to save index file if you want persistent storage
|
35 |
+
self._faiss_index_file = os.path.join(
|
36 |
+
self.global_config["working_dir"], f"faiss_index_{self.namespace}.index"
|
37 |
+
)
|
38 |
+
self._meta_file = self._faiss_index_file + ".meta.json"
|
39 |
+
|
40 |
+
self._max_batch_size = self.global_config["embedding_batch_num"]
|
41 |
+
# Embedding dimension (e.g. 768) must match your embedding function
|
42 |
+
self._dim = self.embedding_func.embedding_dim
|
43 |
+
|
44 |
+
# Create an empty Faiss index for inner product (useful for normalized vectors = cosine similarity).
|
45 |
+
# If you have a large number of vectors, you might want IVF or other indexes.
|
46 |
+
# For demonstration, we use a simple IndexFlatIP.
|
47 |
+
self._index = faiss.IndexFlatIP(self._dim)
|
48 |
+
|
49 |
+
# Keep a local store for metadata, IDs, etc.
|
50 |
+
# Maps <int faiss_id> → metadata (including your original ID).
|
51 |
+
self._id_to_meta = {}
|
52 |
+
|
53 |
+
# Attempt to load an existing index + metadata from disk
|
54 |
+
self._load_faiss_index()
|
55 |
+
|
56 |
+
async def upsert(self, data: dict[str, dict]):
|
57 |
+
"""
|
58 |
+
Insert or update vectors in the Faiss index.
|
59 |
+
|
60 |
+
data: {
|
61 |
+
"custom_id_1": {
|
62 |
+
"content": <text>,
|
63 |
+
...metadata...
|
64 |
+
},
|
65 |
+
"custom_id_2": {
|
66 |
+
"content": <text>,
|
67 |
+
...metadata...
|
68 |
+
},
|
69 |
+
...
|
70 |
+
}
|
71 |
+
"""
|
72 |
+
logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
|
73 |
+
if not data:
|
74 |
+
logger.warning("You are inserting empty data to the vector DB")
|
75 |
+
return []
|
76 |
+
|
77 |
+
current_time = time.time()
|
78 |
+
|
79 |
+
# Prepare data for embedding
|
80 |
+
list_data = []
|
81 |
+
contents = []
|
82 |
+
for k, v in data.items():
|
83 |
+
# Store only known meta fields if needed
|
84 |
+
meta = {mf: v[mf] for mf in self.meta_fields if mf in v}
|
85 |
+
meta["__id__"] = k
|
86 |
+
meta["__created_at__"] = current_time
|
87 |
+
list_data.append(meta)
|
88 |
+
contents.append(v["content"])
|
89 |
+
|
90 |
+
# Split into batches for embedding if needed
|
91 |
+
batches = [
|
92 |
+
contents[i : i + self._max_batch_size]
|
93 |
+
for i in range(0, len(contents), self._max_batch_size)
|
94 |
+
]
|
95 |
+
|
96 |
+
pbar = tqdm_async(total=len(batches), desc="Generating embeddings", unit="batch")
|
97 |
+
|
98 |
+
async def wrapped_task(batch):
|
99 |
+
result = await self.embedding_func(batch)
|
100 |
+
pbar.update(1)
|
101 |
+
return result
|
102 |
+
|
103 |
+
embedding_tasks = [wrapped_task(batch) for batch in batches]
|
104 |
+
embeddings_list = await asyncio.gather(*embedding_tasks)
|
105 |
+
|
106 |
+
# Flatten the list of arrays
|
107 |
+
embeddings = np.concatenate(embeddings_list, axis=0)
|
108 |
+
if len(embeddings) != len(list_data):
|
109 |
+
logger.error(
|
110 |
+
f"Embedding size mismatch. Embeddings: {len(embeddings)}, Data: {len(list_data)}"
|
111 |
+
)
|
112 |
+
return []
|
113 |
+
|
114 |
+
# Normalize embeddings for cosine similarity (in-place)
|
115 |
+
faiss.normalize_L2(embeddings)
|
116 |
+
|
117 |
+
# Upsert logic:
|
118 |
+
# 1. Identify which vectors to remove if they exist
|
119 |
+
# 2. Remove them
|
120 |
+
# 3. Add the new vectors
|
121 |
+
existing_ids_to_remove = []
|
122 |
+
for meta, emb in zip(list_data, embeddings):
|
123 |
+
faiss_internal_id = self._find_faiss_id_by_custom_id(meta["__id__"])
|
124 |
+
if faiss_internal_id is not None:
|
125 |
+
existing_ids_to_remove.append(faiss_internal_id)
|
126 |
+
|
127 |
+
if existing_ids_to_remove:
|
128 |
+
self._remove_faiss_ids(existing_ids_to_remove)
|
129 |
+
|
130 |
+
# Step 2: Add new vectors
|
131 |
+
start_idx = self._index.ntotal
|
132 |
+
self._index.add(embeddings)
|
133 |
+
|
134 |
+
# Step 3: Store metadata + vector for each new ID
|
135 |
+
for i, meta in enumerate(list_data):
|
136 |
+
fid = start_idx + i
|
137 |
+
# Store the raw vector so we can rebuild if something is removed
|
138 |
+
meta["__vector__"] = embeddings[i].tolist()
|
139 |
+
self._id_to_meta[fid] = meta
|
140 |
+
|
141 |
+
logger.info(f"Upserted {len(list_data)} vectors into Faiss index.")
|
142 |
+
return [m["__id__"] for m in list_data]
|
143 |
+
|
144 |
+
async def query(self, query: str, top_k=5):
|
145 |
+
"""
|
146 |
+
Search by a textual query; returns top_k results with their metadata + similarity distance.
|
147 |
+
"""
|
148 |
+
embedding = await self.embedding_func([query])
|
149 |
+
# embedding is shape (1, dim)
|
150 |
+
embedding = np.array(embedding, dtype=np.float32)
|
151 |
+
faiss.normalize_L2(embedding) # we do in-place normalization
|
152 |
+
|
153 |
+
logger.info(
|
154 |
+
f"Query: {query}, top_k: {top_k}, threshold: {self.cosine_better_than_threshold}"
|
155 |
+
)
|
156 |
+
|
157 |
+
# Perform the similarity search
|
158 |
+
distances, indices = self._index.search(embedding, top_k)
|
159 |
+
|
160 |
+
distances = distances[0]
|
161 |
+
indices = indices[0]
|
162 |
+
|
163 |
+
results = []
|
164 |
+
for dist, idx in zip(distances, indices):
|
165 |
+
if idx == -1:
|
166 |
+
# Faiss returns -1 if no neighbor
|
167 |
+
continue
|
168 |
+
|
169 |
+
# Cosine similarity threshold
|
170 |
+
if dist < self.cosine_better_than_threshold:
|
171 |
+
continue
|
172 |
+
|
173 |
+
meta = self._id_to_meta.get(idx, {})
|
174 |
+
results.append(
|
175 |
+
{
|
176 |
+
**meta,
|
177 |
+
"id": meta.get("__id__"),
|
178 |
+
"distance": float(dist),
|
179 |
+
"created_at": meta.get("__created_at__"),
|
180 |
+
}
|
181 |
+
)
|
182 |
+
|
183 |
+
return results
|
184 |
+
|
185 |
+
@property
|
186 |
+
def client_storage(self):
|
187 |
+
# Return whatever structure LightRAG might need for debugging
|
188 |
+
return {"data": list(self._id_to_meta.values())}
|
189 |
+
|
190 |
+
async def delete(self, ids: list[str]):
|
191 |
+
"""
|
192 |
+
Delete vectors for the provided custom IDs.
|
193 |
+
"""
|
194 |
+
logger.info(f"Deleting {len(ids)} vectors from {self.namespace}")
|
195 |
+
to_remove = []
|
196 |
+
for cid in ids:
|
197 |
+
fid = self._find_faiss_id_by_custom_id(cid)
|
198 |
+
if fid is not None:
|
199 |
+
to_remove.append(fid)
|
200 |
+
|
201 |
+
if to_remove:
|
202 |
+
self._remove_faiss_ids(to_remove)
|
203 |
+
logger.info(f"Successfully deleted {len(to_remove)} vectors from {self.namespace}")
|
204 |
+
|
205 |
+
async def delete_entity(self, entity_name: str):
|
206 |
+
"""
|
207 |
+
Delete a single entity by computing its hashed ID
|
208 |
+
the same way your code does it with `compute_mdhash_id`.
|
209 |
+
"""
|
210 |
+
entity_id = compute_mdhash_id(entity_name, prefix="ent-")
|
211 |
+
logger.debug(f"Attempting to delete entity {entity_name} with ID {entity_id}")
|
212 |
+
await self.delete([entity_id])
|
213 |
+
|
214 |
+
async def delete_entity_relation(self, entity_name: str):
|
215 |
+
"""
|
216 |
+
Delete relations for a given entity by scanning metadata.
|
217 |
+
"""
|
218 |
+
logger.debug(f"Searching relations for entity {entity_name}")
|
219 |
+
relations = []
|
220 |
+
for fid, meta in self._id_to_meta.items():
|
221 |
+
if meta.get("src_id") == entity_name or meta.get("tgt_id") == entity_name:
|
222 |
+
relations.append(fid)
|
223 |
+
|
224 |
+
logger.debug(f"Found {len(relations)} relations for {entity_name}")
|
225 |
+
if relations:
|
226 |
+
self._remove_faiss_ids(relations)
|
227 |
+
logger.debug(f"Deleted {len(relations)} relations for {entity_name}")
|
228 |
+
|
229 |
+
async def index_done_callback(self):
|
230 |
+
"""
|
231 |
+
Called after indexing is done (save Faiss index + metadata).
|
232 |
+
"""
|
233 |
+
self._save_faiss_index()
|
234 |
+
logger.info("Faiss index saved successfully.")
|
235 |
+
|
236 |
+
# --------------------------------------------------------------------------------
|
237 |
+
# Internal helper methods
|
238 |
+
# --------------------------------------------------------------------------------
|
239 |
+
|
240 |
+
def _find_faiss_id_by_custom_id(self, custom_id: str):
|
241 |
+
"""
|
242 |
+
Return the Faiss internal ID for a given custom ID, or None if not found.
|
243 |
+
"""
|
244 |
+
for fid, meta in self._id_to_meta.items():
|
245 |
+
if meta.get("__id__") == custom_id:
|
246 |
+
return fid
|
247 |
+
return None
|
248 |
+
|
249 |
+
def _remove_faiss_ids(self, fid_list):
|
250 |
+
"""
|
251 |
+
Remove a list of internal Faiss IDs from the index.
|
252 |
+
Because IndexFlatIP doesn't support 'removals',
|
253 |
+
we rebuild the index excluding those vectors.
|
254 |
+
"""
|
255 |
+
keep_fids = [fid for fid in self._id_to_meta if fid not in fid_list]
|
256 |
+
|
257 |
+
# Rebuild the index
|
258 |
+
vectors_to_keep = []
|
259 |
+
new_id_to_meta = {}
|
260 |
+
for new_fid, old_fid in enumerate(keep_fids):
|
261 |
+
vec_meta = self._id_to_meta[old_fid]
|
262 |
+
vectors_to_keep.append(vec_meta["__vector__"]) # stored as list
|
263 |
+
new_id_to_meta[new_fid] = vec_meta
|
264 |
+
|
265 |
+
# Re-init index
|
266 |
+
self._index = faiss.IndexFlatIP(self._dim)
|
267 |
+
if vectors_to_keep:
|
268 |
+
arr = np.array(vectors_to_keep, dtype=np.float32)
|
269 |
+
self._index.add(arr)
|
270 |
+
|
271 |
+
self._id_to_meta = new_id_to_meta
|
272 |
+
|
273 |
+
def _save_faiss_index(self):
|
274 |
+
"""
|
275 |
+
Save the current Faiss index + metadata to disk so it can persist across runs.
|
276 |
+
"""
|
277 |
+
faiss.write_index(self._index, self._faiss_index_file)
|
278 |
+
|
279 |
+
# Save metadata dict to JSON. Convert all keys to strings for JSON storage.
|
280 |
+
# _id_to_meta is { int: { '__id__': doc_id, '__vector__': [float,...], ... } }
|
281 |
+
# We'll keep the int -> dict, but JSON requires string keys.
|
282 |
+
serializable_dict = {}
|
283 |
+
for fid, meta in self._id_to_meta.items():
|
284 |
+
serializable_dict[str(fid)] = meta
|
285 |
+
|
286 |
+
with open(self._meta_file, "w", encoding="utf-8") as f:
|
287 |
+
json.dump(serializable_dict, f)
|
288 |
+
|
289 |
+
def _load_faiss_index(self):
|
290 |
+
"""
|
291 |
+
Load the Faiss index + metadata from disk if it exists,
|
292 |
+
and rebuild in-memory structures so we can query.
|
293 |
+
"""
|
294 |
+
if not os.path.exists(self._faiss_index_file):
|
295 |
+
logger.warning("No existing Faiss index file found. Starting fresh.")
|
296 |
+
return
|
297 |
+
|
298 |
+
try:
|
299 |
+
# Load the Faiss index
|
300 |
+
self._index = faiss.read_index(self._faiss_index_file)
|
301 |
+
# Load metadata
|
302 |
+
with open(self._meta_file, "r", encoding="utf-8") as f:
|
303 |
+
stored_dict = json.load(f)
|
304 |
+
|
305 |
+
# Convert string keys back to int
|
306 |
+
self._id_to_meta = {}
|
307 |
+
for fid_str, meta in stored_dict.items():
|
308 |
+
fid = int(fid_str)
|
309 |
+
self._id_to_meta[fid] = meta
|
310 |
+
|
311 |
+
logger.info(
|
312 |
+
f"Faiss index loaded with {self._index.ntotal} vectors from {self._faiss_index_file}"
|
313 |
+
)
|
314 |
+
except Exception as e:
|
315 |
+
logger.error(f"Failed to load Faiss index or metadata: {e}")
|
316 |
+
logger.warning("Starting with an empty Faiss index.")
|
317 |
+
self._index = faiss.IndexFlatIP(self._dim)
|
318 |
+
self._id_to_meta = {}
|
lightrag/lightrag.py
CHANGED
@@ -60,6 +60,7 @@ STORAGES = {
|
|
60 |
"PGGraphStorage": ".kg.postgres_impl",
|
61 |
"GremlinStorage": ".kg.gremlin_impl",
|
62 |
"PGDocStatusStorage": ".kg.postgres_impl",
|
|
|
63 |
}
|
64 |
|
65 |
|
|
|
60 |
"PGGraphStorage": ".kg.postgres_impl",
|
61 |
"GremlinStorage": ".kg.gremlin_impl",
|
62 |
"PGDocStatusStorage": ".kg.postgres_impl",
|
63 |
+
"FaissVectorDBStorage": ".kg.faiss_impl",
|
64 |
}
|
65 |
|
66 |
|