pre-commit fix tidb
- lightrag/kg/tidb_impl.py +70 -62
- lightrag/lightrag.py +1 -0
- requirements.txt +2 -2
lightrag/kg/tidb_impl.py
CHANGED
@@ -19,8 +19,10 @@ class TiDB(object):
         self.password = config.get("password", None)
         self.database = config.get("database", None)
         self.workspace = config.get("workspace", None)
+        connection_string = (
+            f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
+            f"?ssl_verify_cert=true&ssl_verify_identity=true"
+        )

         try:
             self.engine = create_engine(connection_string)
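For context, the multi-line f-string just splits one ordinary SQLAlchemy URL. A minimal standalone sketch of the same pattern, assuming a reachable TiDB instance; host, credentials, and database are illustrative (4000 is TiDB's default MySQL port):

    from sqlalchemy import create_engine, text

    # TiDB speaks the MySQL wire protocol, so SQLAlchemy can drive it through
    # PyMySQL, the driver named in the "mysql+pymysql://" URL scheme and the
    # reason requirements.txt gains a pymysql entry in this commit.
    connection_string = (
        "mysql+pymysql://root:secret@127.0.0.1:4000/test"
        "?ssl_verify_cert=true&ssl_verify_identity=true"
    )
    engine = create_engine(connection_string)
    with engine.connect() as conn:
        print(conn.execute(text("SELECT VERSION()")).scalar())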
@@ -49,7 +51,7 @@ class TiDB(object):
         self, sql: str, params: dict = None, multirows: bool = False
     ) -> Union[dict, None]:
         if params is None:
-            params = {
+            params = {"workspace": self.workspace}
         else:
             params.update({"workspace": self.workspace})
         with self.engine.connect() as conn, conn.begin():
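The effect is that every statement is bound with the workspace, whether or not the caller supplied other parameters. A standalone sketch of that flow, assuming SQLAlchemy 1.4+; the helper name `run` and the engine URL are made up for illustration:

    from sqlalchemy import create_engine, text

    engine = create_engine("mysql+pymysql://root:secret@127.0.0.1:4000/test")

    def run(sql: str, params: dict = None, workspace: str = "default"):
        # Mirror of the fixed branch: inject the workspace whether or not
        # the caller supplied other bind parameters.
        if params is None:
            params = {"workspace": workspace}
        else:
            params.update({"workspace": workspace})
        with engine.connect() as conn, conn.begin():
            # :workspace in the SQL is bound by name from the params dict
            return conn.execute(text(sql), params).mappings().all()

    rows = run("SELECT doc_id FROM LIGHTRAG_DOC_FULL WHERE workspace = :workspace")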
@@ -130,8 +132,8 @@ class TiDBKVStorage(BaseKVStorage):
         """Filter out duplicate content"""
         SQL = SQL_TEMPLATES["filter_keys"].format(
             table_name=N_T[self.namespace],
-            id_field=
-            ids=",".join([f"'{id}'" for id in keys])
+            id_field=N_ID[self.namespace],
+            ids=",".join([f"'{id}'" for id in keys]),
         )
         try:
             await self.db.query(SQL)
@@ -161,7 +163,7 @@ class TiDBKVStorage(BaseKVStorage):
         ]
         contents = [v["content"] for v in data.values()]
         batches = [
-            contents[i: i + self._max_batch_size]
+            contents[i : i + self._max_batch_size]
             for i in range(0, len(contents), self._max_batch_size)
         ]
         embeddings_list = await asyncio.gather(
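The spacing change in the slice is black formatting only; the surrounding idiom splits `contents` into fixed-size batches and embeds them concurrently. A runnable sketch with a dummy `embedding_func` standing in for the real embedding call:

    import asyncio

    async def embedding_func(batch):  # stand-in for the real embedding call
        return [[0.0, 0.0, 0.0] for _ in batch]

    async def main():
        contents = [f"chunk-{i}" for i in range(10)]
        max_batch_size = 4
        batches = [
            contents[i : i + max_batch_size]
            for i in range(0, len(contents), max_batch_size)
        ]
        # gather preserves order, so results line up with `batches`
        embeddings_list = await asyncio.gather(*[embedding_func(b) for b in batches])
        print([len(b) for b in batches])  # [4, 4, 2]

    asyncio.run(main())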
@@ -174,26 +176,30 @@ class TiDBKVStorage(BaseKVStorage):
             merge_sql = SQL_TEMPLATES["upsert_chunk"]
             data = []
             for item in list_data:
-                data.append(
+                data.append(
+                    {
+                        "id": item["__id__"],
+                        "content": item["content"],
+                        "tokens": item["tokens"],
+                        "chunk_order_index": item["chunk_order_index"],
+                        "full_doc_id": item["full_doc_id"],
+                        "content_vector": f"{item['__vector__'].tolist()}",
+                        "workspace": self.db.workspace,
+                    }
+                )
             await self.db.execute(merge_sql, data)

         if self.namespace == "full_docs":
             merge_sql = SQL_TEMPLATES["upsert_doc_full"]
             data = []
             for k, v in self._data.items():
-                data.append(
+                data.append(
+                    {
+                        "id": k,
+                        "content": v["content"],
+                        "workspace": self.db.workspace,
+                    }
+                )
             await self.db.execute(merge_sql, data)
         return left_data
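Passing a list of dicts to `execute` makes SQLAlchemy issue the statement once per dict (executemany-style), each value bound by name. A hedged sketch against `upsert_doc_full`; the VALUES clause is inferred from the bound names in this commit, and the engine URL and rows are made up:

    from sqlalchemy import create_engine, text

    engine = create_engine("mysql+pymysql://root:secret@127.0.0.1:4000/test")

    upsert_doc_full = """
        INSERT INTO LIGHTRAG_DOC_FULL (doc_id, content, workspace)
        VALUES (:id, :content, :workspace)
        ON DUPLICATE KEY UPDATE content = VALUES(content),
            workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
    """

    data = [
        {"id": "doc-1", "content": "first document", "workspace": "default"},
        {"id": "doc-2", "content": "second document", "workspace": "default"},
    ]
    with engine.connect() as conn, conn.begin():
        conn.execute(text(upsert_doc_full), data)  # one upsert per dict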
@@ -201,6 +207,7 @@ class TiDBKVStorage(BaseKVStorage):
         if self.namespace in ["full_docs", "text_chunks"]:
             logger.info("full doc and chunk data had been saved into TiDB db!")

+
 @dataclass
 class TiDBVectorDBStorage(BaseVectorStorage):
     cosine_better_than_threshold: float = 0.2
@@ -215,7 +222,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
         )

     async def query(self, query: str, top_k: int) -> list[dict]:
-        """
+        """search from tidb vector"""

         embeddings = await self.embedding_func([query])
         embedding = embeddings[0]
@@ -228,8 +235,10 @@ class TiDBVectorDBStorage(BaseVectorStorage):
             "better_than_threshold": self.cosine_better_than_threshold,
         }

+        results = await self.db.query(
+            SQL_TEMPLATES[self.namespace], params=params, multirows=True
+        )
+        print("vector search result:", results)
         if not results:
             return []
         return results
@@ -253,16 +262,16 @@ class TiDBVectorDBStorage(BaseVectorStorage):
         ]
         contents = [v["content"] for v in data.values()]
         batches = [
-            contents[i: i + self._max_batch_size]
+            contents[i : i + self._max_batch_size]
             for i in range(0, len(contents), self._max_batch_size)
         ]
         embedding_tasks = [self.embedding_func(batch) for batch in batches]
         embeddings_list = []
         for f in tqdm(
+            asyncio.as_completed(embedding_tasks),
+            total=len(embedding_tasks),
+            desc="Generating embeddings",
+            unit="batch",
         ):
             embeddings = await f
             embeddings_list.append(embeddings)
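Wrapping `asyncio.as_completed` in `tqdm` ticks the progress bar as each embedding batch finishes. A runnable sketch with a stand-in coroutine; note that, unlike the `asyncio.gather` path above, `as_completed` yields results in completion order, so `embeddings_list` is not guaranteed to line up with `batches`:

    import asyncio
    from tqdm import tqdm

    async def embedding_func(batch):  # stand-in for the real embedding call
        await asyncio.sleep(0.01 * len(batch))
        return [[0.0, 0.0, 0.0] for _ in batch]

    async def main():
        batches = [["a", "b"], ["c"], ["d", "e", "f"]]
        embedding_tasks = [embedding_func(batch) for batch in batches]
        embeddings_list = []
        for f in tqdm(
            asyncio.as_completed(embedding_tasks),
            total=len(embedding_tasks),
            desc="Generating embeddings",
            unit="batch",
        ):
            embeddings_list.append(await f)
        print(len(embeddings_list))  # 3 batches, possibly reordered

    asyncio.run(main())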
@@ -274,27 +283,31 @@ class TiDBVectorDBStorage(BaseVectorStorage):
             data = []
             for item in list_data:
                 merge_sql = SQL_TEMPLATES["upsert_entity"]
-                data.append(
+                data.append(
+                    {
+                        "id": item["id"],
+                        "name": item["entity_name"],
+                        "content": item["content"],
+                        "content_vector": f"{item['content_vector'].tolist()}",
+                        "workspace": self.db.workspace,
+                    }
+                )
             await self.db.execute(merge_sql, data)

         elif self.namespace == "relationships":
             data = []
             for item in list_data:
                 merge_sql = SQL_TEMPLATES["upsert_relationship"]
-                data.append(
+                data.append(
+                    {
+                        "id": item["id"],
+                        "source_name": item["src_id"],
+                        "target_name": item["tgt_id"],
+                        "content": item["content"],
+                        "content_vector": f"{item['content_vector'].tolist()}",
+                        "workspace": self.db.workspace,
+                    }
+                )
             await self.db.execute(merge_sql, data)
@@ -346,8 +359,7 @@ TABLES = {
         """
     },
     "LIGHTRAG_GRAPH_NODES": {
-        "ddl":
-        """
+        "ddl": """
         CREATE TABLE LIGHTRAG_GRAPH_NODES (
             `id` BIGINT PRIMARY KEY AUTO_RANDOM,
             `entity_id` VARCHAR(256) NOT NULL,
@@ -362,8 +374,7 @@ TABLES = {
         """
     },
     "LIGHTRAG_GRAPH_EDGES": {
-        "ddl":
-        """
+        "ddl": """
         CREATE TABLE LIGHTRAG_GRAPH_EDGES (
             `id` BIGINT PRIMARY KEY AUTO_RANDOM,
             `relation_id` VARCHAR(256) NOT NULL,
@@ -400,7 +411,6 @@ SQL_TEMPLATES = {
     "get_by_ids_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM LIGHTRAG_DOC_FULL WHERE doc_id IN ({ids}) AND workspace = :workspace",
     "get_by_ids_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE chunk_id IN ({ids}) AND workspace = :workspace",
     "filter_keys": "SELECT {id_field} AS id FROM {table_name} WHERE {id_field} IN ({ids}) AND workspace = :workspace",
-
     # SQL for Merge operations (TiDB version with INSERT ... ON DUPLICATE KEY UPDATE)
     "upsert_doc_full": """
         INSERT INTO LIGHTRAG_DOC_FULL (doc_id, content, workspace)
@@ -408,13 +418,12 @@ SQL_TEMPLATES = {
         ON DUPLICATE KEY UPDATE content = VALUES(content), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
     """,
     "upsert_chunk": """
+        INSERT INTO LIGHTRAG_DOC_CHUNKS(chunk_id, content, tokens, chunk_order_index, full_doc_id, content_vector, workspace)
         VALUES (:id, :content, :tokens, :chunk_order_index, :full_doc_id, :content_vector, :workspace)
+        ON DUPLICATE KEY UPDATE
+        content = VALUES(content), tokens = VALUES(tokens), chunk_order_index = VALUES(chunk_order_index),
         full_doc_id = VALUES(full_doc_id), content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
     """,
-
     # SQL for VectorStorage
     "entities": """SELECT n.name as entity_name FROM
         (SELECT entity_id as id, name, VEC_COSINE_DISTANCE(content_vector,:embedding_string) as distance
@@ -428,19 +437,18 @@ SQL_TEMPLATES = {
         (SELECT chunk_id as id,VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance
         FROM LIGHTRAG_DOC_CHUNKS WHERE workspace = :workspace) c
         WHERE c.distance>:better_than_threshold ORDER BY c.distance DESC LIMIT :top_k""",
-
     "upsert_entity": """
+        INSERT INTO LIGHTRAG_GRAPH_NODES(entity_id, name, content, content_vector, workspace)
+        VALUES(:id, :name, :content, :content_vector, :workspace)
+        ON DUPLICATE KEY UPDATE
+        name = VALUES(name), content = VALUES(content), content_vector = VALUES(content_vector),
         workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
     """,
     "upsert_relationship": """
+        INSERT INTO LIGHTRAG_GRAPH_EDGES(relation_id, source_name, target_name, content, content_vector, workspace)
         VALUES(:id, :source_name, :target_name, :content, :content_vector, :workspace)
+        ON DUPLICATE KEY UPDATE
+        source_name = VALUES(source_name), target_name = VALUES(target_name), content = VALUES(content),
         content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
-    """
+    """,
 }
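INSERT ... ON DUPLICATE KEY UPDATE is the MySQL/TiDB upsert form: VALUES(col) refers to the value the INSERT attempted for that column. An end-to-end sketch using the relationship template above; the connection URL and sample row are made up, and the vector is passed as a string literal exactly as the storage classes do:

    from sqlalchemy import create_engine, text

    engine = create_engine("mysql+pymysql://root:secret@127.0.0.1:4000/test")

    upsert_relationship = """
        INSERT INTO LIGHTRAG_GRAPH_EDGES(relation_id, source_name, target_name, content, content_vector, workspace)
        VALUES(:id, :source_name, :target_name, :content, :content_vector, :workspace)
        ON DUPLICATE KEY UPDATE
        source_name = VALUES(source_name), target_name = VALUES(target_name), content = VALUES(content),
        content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
    """

    row = {
        "id": "rel-1",
        "source_name": "alice",
        "target_name": "bob",
        "content": "alice knows bob",
        "content_vector": "[0.1, 0.2, 0.3]",  # TiDB vector literal as a string
        "workspace": "default",
    }
    # Running this twice inserts once, then updates the same relation_id row.
    with engine.connect() as conn, conn.begin():
        conn.execute(text(upsert_relationship), row)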
lightrag/lightrag.py
CHANGED
@@ -80,6 +80,7 @@ ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage")
 TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage")
 TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage")

+
 def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     """
     Ensure that there is always an event loop available.
requirements.txt
CHANGED
@@ -13,11 +13,11 @@ openai
 oracledb
 pymilvus
 pymongo
+pymysql
 pyvis
-tenacity
 # lmdeploy[all]
 sqlalchemy
-
+tenacity


 # LLM packages