Commit c77f948 · jin committed
Parent(s): 5dcb28f

use oracle bind variables to avoid error

Files changed:
- lightrag/base.py +1 -0
- lightrag/kg/oracle_impl.py +175 -143
- lightrag/operate.py +13 -3
- lightrag/utils.py +5 -1
lightrag/base.py
CHANGED

@@ -17,6 +17,7 @@ T = TypeVar("T")
 class QueryParam:
     mode: Literal["local", "global", "hybrid", "naive"] = "global"
     only_need_context: bool = False
+    only_need_prompt: bool = False
     response_type: str = "Multiple Paragraphs"
     # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
     top_k: int = 60
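The new flag mirrors the existing only_need_context switch: when set, each query function in lightrag/operate.py returns the assembled system prompt instead of calling the LLM (see the operate.py hunks below). A minimal usage sketch, assuming QueryParam is a dataclass (its decorator sits above the hunk shown) and that the usual LightRAG entry point is used; the rag instance and its setup are assumptions, not part of this commit:

from lightrag import LightRAG, QueryParam  # top-level exports, as in the project README

# Assumed setup elsewhere: rag = LightRAG(working_dir="./ragtest", ...)
# With only_need_prompt=True the pipeline stops after prompt assembly and
# returns the formatted system prompt rather than an LLM answer.
prompt = rag.query(
    "What are the top themes in this story?",
    param=QueryParam(mode="hybrid", only_need_prompt=True),
)
print(prompt)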
lightrag/kg/oracle_impl.py
CHANGED

(Note: several removed lines below were truncated by the page rendering; they are shown as far as they survive.)

@@ -114,16 +114,17 @@ class OracleDB:
 
         logger.info("Finished check all tables in Oracle database")
 
-    async def query(self, sql: str, multirows: bool = False) -> Union[dict, None]:
+    async def query(self, sql: str, params: dict = None, multirows: bool = False) -> Union[dict, None]:
         async with self.pool.acquire() as connection:
             connection.inputtypehandler = self.input_type_handler
             connection.outputtypehandler = self.output_type_handler
             with connection.cursor() as cursor:
                 try:
-                    await cursor.execute(sql)
+                    await cursor.execute(sql, params)
                 except Exception as e:
                     logger.error(f"Oracle database error: {e}")
                     print(sql)
+                    print(params)
                     raise
                 columns = [column[0].lower() for column in cursor.description]
                 if multirows:
@@ -140,7 +141,7 @@ class OracleDB:
             data = None
         return data
 
-    async def execute(self, sql: str, data: list = None):
+    async def execute(self, sql: str, data: list | dict = None):
         # logger.info("go into OracleDB execute method")
         try:
             async with self.pool.acquire() as connection:
@@ -172,11 +173,10 @@ class OracleKVStorage(BaseKVStorage):
 
     async def get_by_id(self, id: str) -> Union[dict, None]:
         """Fetch doc_full data by id."""
-        SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format(
-        )
+        SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
+        params = {"workspace":self.db.workspace, "id":id}
         # print("get_by_id:"+SQL)
-        res = await self.db.query(SQL)
+        res = await self.db.query(SQL,params)
         if res:
             data = res  # {"data":res}
             # print (data)
@@ -187,11 +187,11 @@ class OracleKVStorage(BaseKVStorage):
     # Query by id
     async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
         """Fetch doc_chunks data by id."""
-        SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
-        )
-        #
-        res = await self.db.query(SQL, multirows=True)
+        SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(ids=",".join([f"'{id}'" for id in ids]))
+        params = {"workspace":self.db.workspace}
+        # print("get_by_ids:"+SQL)
+        # print(params)
+        res = await self.db.query(SQL,params, multirows=True)
         if res:
             data = res  # [{"data":i} for i in res]
             # print(data)
@@ -201,12 +201,16 @@ class OracleKVStorage(BaseKVStorage):
 
     async def filter_keys(self, keys: list[str]) -> set[str]:
         """Filter out duplicate content."""
-        SQL = SQL_TEMPLATES["filter_keys"].format(
-        )
-        res = await self.db.query(SQL, multirows=True)
+        SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace],
+                                                  ids=",".join([f"'{id}'" for id in keys]))
+        params = {"workspace":self.db.workspace}
+        try:
+            await self.db.query(SQL, params)
+        except Exception as e:
+            logger.error(f"Oracle database error: {e}")
+            print(SQL)
+            print(params)
+        res = await self.db.query(SQL, params,multirows=True)
         data = None
         if res:
             exist_keys = [key["id"] for key in res]
@@ -243,29 +247,31 @@ class OracleKVStorage(BaseKVStorage):
             d["__vector__"] = embeddings[i]
         # print(list_data)
         for item in list_data:
-            merge_sql = SQL_TEMPLATES["merge_chunk"]
-            ]
+            merge_sql = SQL_TEMPLATES["merge_chunk"]
+            data = {"check_id":item["__id__"],
+                    "id":item["__id__"],
+                    "content":item["content"],
+                    "workspace":self.db.workspace,
+                    "tokens":item["tokens"],
+                    "chunk_order_index":item["chunk_order_index"],
+                    "full_doc_id":item["full_doc_id"],
+                    "content_vector":item["__vector__"]
+                    }
             # print(merge_sql)
-            await self.db.execute(merge_sql,
+            await self.db.execute(merge_sql, data)
 
         if self.namespace == "full_docs":
             for k, v in self._data.items():
                 # values.clear()
-                merge_sql = SQL_TEMPLATES["merge_doc_full"]
+                merge_sql = SQL_TEMPLATES["merge_doc_full"]
+                data = {
+                    "check_id":k,
+                    "id":k,
+                    "content":v["content"],
+                    "workspace":self.db.workspace
+                }
                 # print(merge_sql)
-                await self.db.execute(merge_sql,
+                await self.db.execute(merge_sql, data)
         return left_data
 
     async def index_done_callback(self):
@@ -295,18 +301,17 @@ class OracleVectorDBStorage(BaseVectorStorage):
         # Convert precision
         dtype = str(embedding.dtype).upper()
         dimension = embedding.shape[0]
-        embedding_string = ", ".join(map(str, embedding.tolist()))
-
-        SQL = SQL_TEMPLATES[self.namespace].format(
-        )
+        embedding_string = "["+", ".join(map(str, embedding.tolist()))+"]"
+
+        SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype)
+        params = {
+            "embedding_string": embedding_string,
+            "workspace": self.db.workspace,
+            "top_k": top_k,
+            "better_than_threshold": self.cosine_better_than_threshold,
+        }
         # print(SQL)
-        results = await self.db.query(SQL, multirows=True)
+        results = await self.db.query(SQL,params=params, multirows=True)
         # print("vector search result:",results)
         return results
 
@@ -339,22 +344,18 @@ class OracleGraphStorage(BaseGraphStorage):
         )
         embeddings = np.concatenate(embeddings_list)
         content_vector = embeddings[0]
-        merge_sql = SQL_TEMPLATES["merge_node"]
+        merge_sql = SQL_TEMPLATES["merge_node"]
+        data = {
+            "workspace":self.db.workspace,
+            "name":entity_name,
+            "entity_type":entity_type,
+            "description":description,
+            "source_chunk_id":source_id,
+            "content":content,
+            "content_vector":content_vector
+        }
         # print(merge_sql)
-        await self.db.execute(
-            merge_sql,
-            [
-                self.db.workspace,
-                entity_name,
-                entity_type,
-                description,
-                source_id,
-                content,
-                content_vector,
-            ],
-        )
+        await self.db.execute(merge_sql,data)
         # self._graph.add_node(node_id, **node_data)
 
     async def upsert_edge(
@@ -379,27 +380,20 @@ class OracleGraphStorage(BaseGraphStorage):
         )
         embeddings = np.concatenate(embeddings_list)
         content_vector = embeddings[0]
-        merge_sql = SQL_TEMPLATES["merge_edge"]
+        merge_sql = SQL_TEMPLATES["merge_edge"]
+        data = {
+            "workspace":self.db.workspace,
+            "source_name":source_name,
+            "target_name":target_name,
+            "weight":weight,
+            "keywords":keywords,
+            "description":description,
+            "source_chunk_id":source_chunk_id,
+            "content":content,
+            "content_vector":content_vector
+        }
         # print(merge_sql)
-        await self.db.execute(
-            merge_sql,
-            [
-                self.db.workspace,
-                source_name,
-                target_name,
-                weight,
-                keywords,
-                description,
-                source_chunk_id,
-                content,
-                content_vector,
-            ],
-        )
+        await self.db.execute(merge_sql,data)
         # self._graph.add_edge(source_node_id, target_node_id, **edge_data)
 
     async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
@@ -429,12 +423,14 @@ class OracleGraphStorage(BaseGraphStorage):
     #################### query method #################
     async def has_node(self, node_id: str) -> bool:
         """Check whether a node exists by node id."""
-        SQL = SQL_TEMPLATES["has_node"]
+        SQL = SQL_TEMPLATES["has_node"]
+        params = {
+            "workspace":self.db.workspace,
+            "node_id":node_id
+        }
         # print(SQL)
         # print(self.db.workspace, node_id)
-        res = await self.db.query(SQL)
+        res = await self.db.query(SQL,params)
         if res:
             # print("Node exist!",res)
             return True
@@ -444,13 +440,14 @@ class OracleGraphStorage(BaseGraphStorage):
 
     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
         """Check whether an edge exists by source and target node ids."""
-        SQL = SQL_TEMPLATES["has_edge"]
+        SQL = SQL_TEMPLATES["has_edge"]
+        params = {
+            "workspace":self.db.workspace,
+            "source_node_id":source_node_id,
+            "target_node_id":target_node_id
+        }
         # print(SQL)
-        res = await self.db.query(SQL)
+        res = await self.db.query(SQL,params)
         if res:
             # print("Edge exist!",res)
             return True
@@ -460,11 +457,13 @@ class OracleGraphStorage(BaseGraphStorage):
 
     async def node_degree(self, node_id: str) -> int:
         """Get a node's degree by node id."""
-        SQL = SQL_TEMPLATES["node_degree"]
+        SQL = SQL_TEMPLATES["node_degree"]
+        params = {
+            "workspace":self.db.workspace,
+            "node_id":node_id
+        }
         # print(SQL)
-        res = await self.db.query(SQL)
+        res = await self.db.query(SQL,params)
         if res:
             # print("Node degree",res["degree"])
             return res["degree"]
@@ -480,12 +479,14 @@ class OracleGraphStorage(BaseGraphStorage):
 
     async def get_node(self, node_id: str) -> Union[dict, None]:
         """Get node data by node id."""
-        SQL = SQL_TEMPLATES["get_node"]
+        SQL = SQL_TEMPLATES["get_node"]
+        params = {
+            "workspace":self.db.workspace,
+            "node_id":node_id
+        }
         # print(self.db.workspace, node_id)
         # print(SQL)
-        res = await self.db.query(SQL)
+        res = await self.db.query(SQL,params)
         if res:
             # print("Get node!",self.db.workspace, node_id,res)
             return res
@@ -497,12 +498,13 @@ class OracleGraphStorage(BaseGraphStorage):
         self, source_node_id: str, target_node_id: str
     ) -> Union[dict, None]:
         """Get an edge by source and target node ids."""
-        SQL = SQL_TEMPLATES["get_edge"]
-        res = await self.db.query(SQL)
+        SQL = SQL_TEMPLATES["get_edge"]
+        params = {
+            "workspace":self.db.workspace,
+            "source_node_id":source_node_id,
+            "target_node_id":target_node_id
+        }
+        res = await self.db.query(SQL,params)
         if res:
             # print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
             return res
@@ -513,10 +515,12 @@ class OracleGraphStorage(BaseGraphStorage):
     async def get_node_edges(self, source_node_id: str):
         """Get all edges of a node by node id."""
         if await self.has_node(source_node_id):
-            SQL = SQL_TEMPLATES["get_node_edges"]
-            res = await self.db.query(SQL, multirows=True)
+            SQL = SQL_TEMPLATES["get_node_edges"]
+            params = {
+                "workspace":self.db.workspace,
+                "source_node_id":source_node_id
+            }
+            res = await self.db.query(sql=SQL, params=params, multirows=True)
             if res:
                 data = [(i["source_name"], i["target_name"]) for i in res]
                 # print("Get node edge!",self.db.workspace, source_node_id,data)
@@ -524,8 +528,22 @@ class OracleGraphStorage(BaseGraphStorage):
         else:
             # print("Node Edge not exist!",self.db.workspace, source_node_id)
             return []
 
+    async def get_all_nodes(self, limit: int):
+        """Query all nodes."""
+        SQL = SQL_TEMPLATES["get_all_nodes"]
+        params = {"workspace":self.db.workspace, "limit":str(limit)}
+        res = await self.db.query(sql=SQL,params=params, multirows=True)
+        if res:
+            return res
+
+    async def get_all_edges(self, limit: int):
+        """Query all edges."""
+        SQL = SQL_TEMPLATES["get_all_edges"]
+        params = {"workspace":self.db.workspace, "limit":str(limit)}
+        res = await self.db.query(sql=SQL,params=params, multirows=True)
+        if res:
+            return res
 
 N_T = {
     "full_docs": "LIGHTRAG_DOC_FULL",
     "text_chunks": "LIGHTRAG_DOC_CHUNKS",
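The pattern repeated through the hunks above is the point of the commit: SQL is no longer assembled with str.format (which breaks as soon as a value contains a quote, the "error" the commit title refers to, and invites injection), but executed with bind variables, so the driver ships statement text and data separately. A self-contained sketch of the same pattern, assuming python-oracledb 2.x asyncio support; credentials, DSN, and the row queried are placeholders, and in LightRAG this is wrapped by OracleDB.query and OracleDB.execute:

import asyncio
import oracledb

async def main():
    # Placeholder connection details for the sketch.
    pool = oracledb.create_pool_async(user="demo", password="demo", dsn="localhost/freepdb1")
    async with pool.acquire() as connection:
        with connection.cursor() as cursor:
            # Named binds (:workspace, :id) are supplied as a dict, exactly as
            # OracleKVStorage.get_by_id now does; values are never spliced into
            # the SQL text, so quoting and escaping are the driver's job.
            await cursor.execute(
                "select ID from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
                {"workspace": "default", "id": "doc-xyz"},
            )
            print(await cursor.fetchall())
    await pool.close()

asyncio.run(main())

The upsert paths make the same move from positional binding (upsert_node previously passed a plain list) to dicts that bind by name, which cursor.execute accepts interchangeably. The SQL_TEMPLATES hunk below carries the matching statement changes.

@@ -619,82 +637,96 @@ TABLES = {
 
 SQL_TEMPLATES = {
     # SQL for KVStorage
-    "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace
-    "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace
-    "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace
-    "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace
-    "filter_keys": "select id from {table_name} where workspace
+    "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
+    "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id",
+    "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID in ({ids})",
+    "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})",
+    "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
     "merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a
     USING DUAL
-    ON (a.id =
+    ON (a.id = :check_id)
     WHEN NOT MATCHED THEN
-    INSERT(id,content,workspace) values(:
+    INSERT(id,content,workspace) values(:id,:content,:workspace)
     """,
     "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a
     USING DUAL
-    ON (a.id =
+    ON (a.id = :check_id)
     WHEN NOT MATCHED THEN
     INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector)
-    values (:
+    values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector) """,
     # SQL for VectorStorage
     "entities": """SELECT name as entity_name FROM
-    (SELECT id,name,VECTOR_DISTANCE(content_vector,vector(
-    FROM LIGHTRAG_GRAPH_NODES WHERE workspace
-    WHERE distance
+    (SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
+    FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace)
+    WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
     "relationships": """SELECT source_name as src_id, target_name as tgt_id FROM
-    (SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector(
-    FROM LIGHTRAG_GRAPH_EDGES WHERE workspace
-    WHERE distance
+    (SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
+    FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace)
+    WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
    "chunks": """SELECT id FROM
-    (SELECT id,VECTOR_DISTANCE(content_vector,vector(
-    FROM LIGHTRAG_DOC_CHUNKS WHERE workspace
-    WHERE distance
+    (SELECT id,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
+    FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace)
+    WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
     # SQL for GraphStorage
     "has_node": """SELECT * FROM GRAPH_TABLE (lightrag_graph
     MATCH (a)
-    WHERE a.workspace
+    WHERE a.workspace=:workspace AND a.name=:node_id
     COLUMNS (a.name))""",
     "has_edge": """SELECT * FROM GRAPH_TABLE (lightrag_graph
     MATCH (a) -[e]-> (b)
-    WHERE e.workspace
-    AND a.name
+    WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
+    AND a.name=:source_node_id AND b.name=:target_node_id
     COLUMNS (e.source_name,e.target_name) )""",
     "node_degree": """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph
     MATCH (a)-[e]->(b)
-    WHERE a.workspace
-    AND a.name
+    WHERE a.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
+    AND a.name=:node_id or b.name = :node_id
     COLUMNS (a.name))""",
     "get_node": """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description
     FROM GRAPH_TABLE (lightrag_graph
     MATCH (a)
-    WHERE a.workspace
+    WHERE a.workspace=:workspace AND a.name=:node_id
     COLUMNS (a.name)
     ) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name
-    WHERE t2.workspace
+    WHERE t2.workspace=:workspace""",
     "get_edge": """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords,
     NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords
     FROM GRAPH_TABLE (lightrag_graph
     MATCH (a)-[e]->(b)
-    WHERE e.workspace
-    AND a.name
+    WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
+    AND a.name=:source_node_id and b.name = :target_node_id
     COLUMNS (e.id,a.name as source_id)
     ) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""",
     "get_node_edges": """SELECT source_name,target_name
     FROM GRAPH_TABLE (lightrag_graph
     MATCH (a)-[e]->(b)
-    WHERE e.workspace
-    AND a.name
+    WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
+    AND a.name=:source_node_id
     COLUMNS (a.name as source_name,b.name as target_name))""",
     "merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
     USING DUAL
-    ON (a.workspace =
+    ON (a.workspace = :workspace and a.name=:name and a.source_chunk_id=:source_chunk_id)
     WHEN NOT MATCHED THEN
     INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector)
-    values (:
+    values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector) """,
     "merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a
     USING DUAL
-    ON (a.workspace =
+    ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and a.source_chunk_id=:source_chunk_id)
     WHEN NOT MATCHED THEN
     INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
-    values (:
+    values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """,
+    "get_all_nodes":"""SELECT t1.name as id,t1.entity_type as label,t1.DESCRIPTION,t2.content
+    FROM LIGHTRAG_GRAPH_NODES t1
+    LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
+    WHERE t1.workspace=:workspace
+    order by t1.CREATETIME DESC
+    fetch first :limit rows only
+    """,
+    "get_all_edges":"""SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target,
+    t1.weight,t1.DESCRIPTION,t2.content
+    FROM LIGHTRAG_GRAPH_EDGES t1
+    LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
+    WHERE t1.workspace=:workspace
+    order by t1.CREATETIME DESC
+    fetch first :limit rows only"""
 }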
114 |
|
115 |
logger.info("Finished check all tables in Oracle database")
|
116 |
|
117 |
+
async def query(self, sql: str, params: dict = None, multirows: bool = False) -> Union[dict, None]:
|
118 |
async with self.pool.acquire() as connection:
|
119 |
connection.inputtypehandler = self.input_type_handler
|
120 |
connection.outputtypehandler = self.output_type_handler
|
121 |
with connection.cursor() as cursor:
|
122 |
try:
|
123 |
+
await cursor.execute(sql, params)
|
124 |
except Exception as e:
|
125 |
logger.error(f"Oracle database error: {e}")
|
126 |
print(sql)
|
127 |
+
print(params)
|
128 |
raise
|
129 |
columns = [column[0].lower() for column in cursor.description]
|
130 |
if multirows:
|
|
|
141 |
data = None
|
142 |
return data
|
143 |
|
144 |
+
async def execute(self, sql: str, data: list | dict = None):
|
145 |
# logger.info("go into OracleDB execute method")
|
146 |
try:
|
147 |
async with self.pool.acquire() as connection:
|
|
|
173 |
|
174 |
async def get_by_id(self, id: str) -> Union[dict, None]:
|
175 |
"""根据 id 获取 doc_full 数据."""
|
176 |
+
SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
|
177 |
+
params = {"workspace":self.db.workspace, "id":id}
|
|
|
178 |
# print("get_by_id:"+SQL)
|
179 |
+
res = await self.db.query(SQL,params)
|
180 |
if res:
|
181 |
data = res # {"data":res}
|
182 |
# print (data)
|
|
|
187 |
# Query by id
|
188 |
async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
|
189 |
"""根据 id 获取 doc_chunks 数据"""
|
190 |
+
SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(ids=",".join([f"'{id}'" for id in ids]))
|
191 |
+
params = {"workspace":self.db.workspace}
|
192 |
+
#print("get_by_ids:"+SQL)
|
193 |
+
#print(params)
|
194 |
+
res = await self.db.query(SQL,params, multirows=True)
|
195 |
if res:
|
196 |
data = res # [{"data":i} for i in res]
|
197 |
# print(data)
|
|
|
201 |
|
202 |
async def filter_keys(self, keys: list[str]) -> set[str]:
|
203 |
"""过滤掉重复内容"""
|
204 |
+
SQL = SQL_TEMPLATES["filter_keys"].format(table_name=N_T[self.namespace],
|
205 |
+
ids=",".join([f"'{id}'" for id in keys]))
|
206 |
+
params = {"workspace":self.db.workspace}
|
207 |
+
try:
|
208 |
+
await self.db.query(SQL, params)
|
209 |
+
except Exception as e:
|
210 |
+
logger.error(f"Oracle database error: {e}")
|
211 |
+
print(SQL)
|
212 |
+
print(params)
|
213 |
+
res = await self.db.query(SQL, params,multirows=True)
|
214 |
data = None
|
215 |
if res:
|
216 |
exist_keys = [key["id"] for key in res]
|
|
|
247 |
d["__vector__"] = embeddings[i]
|
248 |
# print(list_data)
|
249 |
for item in list_data:
|
250 |
+
merge_sql = SQL_TEMPLATES["merge_chunk"]
|
251 |
+
data = {"check_id":item["__id__"],
|
252 |
+
"id":item["__id__"],
|
253 |
+
"content":item["content"],
|
254 |
+
"workspace":self.db.workspace,
|
255 |
+
"tokens":item["tokens"],
|
256 |
+
"chunk_order_index":item["chunk_order_index"],
|
257 |
+
"full_doc_id":item["full_doc_id"],
|
258 |
+
"content_vector":item["__vector__"]
|
259 |
+
}
|
|
|
260 |
# print(merge_sql)
|
261 |
+
await self.db.execute(merge_sql, data)
|
262 |
|
263 |
if self.namespace == "full_docs":
|
264 |
for k, v in self._data.items():
|
265 |
# values.clear()
|
266 |
+
merge_sql = SQL_TEMPLATES["merge_doc_full"]
|
267 |
+
data = {
|
268 |
+
"check_id":k,
|
269 |
+
"id":k,
|
270 |
+
"content":v["content"],
|
271 |
+
"workspace":self.db.workspace
|
272 |
+
}
|
273 |
# print(merge_sql)
|
274 |
+
await self.db.execute(merge_sql, data)
|
275 |
return left_data
|
276 |
|
277 |
async def index_done_callback(self):
|
|
|
301 |
# 转换精度
|
302 |
dtype = str(embedding.dtype).upper()
|
303 |
dimension = embedding.shape[0]
|
304 |
+
embedding_string = "["+", ".join(map(str, embedding.tolist()))+"]"
|
305 |
+
|
306 |
+
SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype)
|
307 |
+
params = {
|
308 |
+
"embedding_string": embedding_string,
|
309 |
+
"workspace": self.db.workspace,
|
310 |
+
"top_k": top_k,
|
311 |
+
"better_than_threshold": self.cosine_better_than_threshold,
|
312 |
+
}
|
|
|
313 |
# print(SQL)
|
314 |
+
results = await self.db.query(SQL,params=params, multirows=True)
|
315 |
# print("vector search result:",results)
|
316 |
return results
|
317 |
|
|
|
344 |
)
|
345 |
embeddings = np.concatenate(embeddings_list)
|
346 |
content_vector = embeddings[0]
|
347 |
+
merge_sql = SQL_TEMPLATES["merge_node"]
|
348 |
+
data = {
|
349 |
+
"workspace":self.db.workspace,
|
350 |
+
"name":entity_name,
|
351 |
+
"entity_type":entity_type,
|
352 |
+
"description":description,
|
353 |
+
"source_chunk_id":source_id,
|
354 |
+
"content":content,
|
355 |
+
"content_vector":content_vector
|
356 |
+
}
|
357 |
# print(merge_sql)
|
358 |
+
await self.db.execute(merge_sql,data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
359 |
# self._graph.add_node(node_id, **node_data)
|
360 |
|
361 |
async def upsert_edge(
|
|
|
380 |
)
|
381 |
embeddings = np.concatenate(embeddings_list)
|
382 |
content_vector = embeddings[0]
|
383 |
+
merge_sql = SQL_TEMPLATES["merge_edge"]
|
384 |
+
data = {
|
385 |
+
"workspace":self.db.workspace,
|
386 |
+
"source_name":source_name,
|
387 |
+
"target_name":target_name,
|
388 |
+
"weight":weight,
|
389 |
+
"keywords":keywords,
|
390 |
+
"description":description,
|
391 |
+
"source_chunk_id":source_chunk_id,
|
392 |
+
"content":content,
|
393 |
+
"content_vector":content_vector
|
394 |
+
}
|
395 |
# print(merge_sql)
|
396 |
+
await self.db.execute(merge_sql,data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
397 |
# self._graph.add_edge(source_node_id, target_node_id, **edge_data)
|
398 |
|
399 |
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
|
|
423 |
#################### query method #################
|
424 |
async def has_node(self, node_id: str) -> bool:
|
425 |
"""根据节点id检查节点是否存在"""
|
426 |
+
SQL = SQL_TEMPLATES["has_node"]
|
427 |
+
params = {
|
428 |
+
"workspace":self.db.workspace,
|
429 |
+
"node_id":node_id
|
430 |
+
}
|
431 |
# print(SQL)
|
432 |
# print(self.db.workspace, node_id)
|
433 |
+
res = await self.db.query(SQL,params)
|
434 |
if res:
|
435 |
# print("Node exist!",res)
|
436 |
return True
|
|
|
440 |
|
441 |
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
442 |
"""根据源和目标节点id检查边是否存在"""
|
443 |
+
SQL = SQL_TEMPLATES["has_edge"]
|
444 |
+
params = {
|
445 |
+
"workspace":self.db.workspace,
|
446 |
+
"source_node_id":source_node_id,
|
447 |
+
"target_node_id":target_node_id
|
448 |
+
}
|
449 |
# print(SQL)
|
450 |
+
res = await self.db.query(SQL,params)
|
451 |
if res:
|
452 |
# print("Edge exist!",res)
|
453 |
return True
|
|
|
457 |
|
458 |
async def node_degree(self, node_id: str) -> int:
|
459 |
"""根据节点id获取节点的度"""
|
460 |
+
SQL = SQL_TEMPLATES["node_degree"]
|
461 |
+
params = {
|
462 |
+
"workspace":self.db.workspace,
|
463 |
+
"node_id":node_id
|
464 |
+
}
|
465 |
# print(SQL)
|
466 |
+
res = await self.db.query(SQL,params)
|
467 |
if res:
|
468 |
# print("Node degree",res["degree"])
|
469 |
return res["degree"]
|
|
|
479 |
|
480 |
async def get_node(self, node_id: str) -> Union[dict, None]:
|
481 |
"""根据节点id获取节点数据"""
|
482 |
+
SQL = SQL_TEMPLATES["get_node"]
|
483 |
+
params = {
|
484 |
+
"workspace":self.db.workspace,
|
485 |
+
"node_id":node_id
|
486 |
+
}
|
487 |
# print(self.db.workspace, node_id)
|
488 |
# print(SQL)
|
489 |
+
res = await self.db.query(SQL,params)
|
490 |
if res:
|
491 |
# print("Get node!",self.db.workspace, node_id,res)
|
492 |
return res
|
|
|
498 |
self, source_node_id: str, target_node_id: str
|
499 |
) -> Union[dict, None]:
|
500 |
"""根据源和目标节点id获取边"""
|
501 |
+
SQL = SQL_TEMPLATES["get_edge"]
|
502 |
+
params = {
|
503 |
+
"workspace":self.db.workspace,
|
504 |
+
"source_node_id":source_node_id,
|
505 |
+
"target_node_id":target_node_id
|
506 |
+
}
|
507 |
+
res = await self.db.query(SQL,params)
|
508 |
if res:
|
509 |
# print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
|
510 |
return res
|
|
|
515 |
async def get_node_edges(self, source_node_id: str):
|
516 |
"""根据节点id获取节点的所有边"""
|
517 |
if await self.has_node(source_node_id):
|
518 |
+
SQL = SQL_TEMPLATES["get_node_edges"]
|
519 |
+
params = {
|
520 |
+
"workspace":self.db.workspace,
|
521 |
+
"source_node_id":source_node_id
|
522 |
+
}
|
523 |
+
res = await self.db.query(sql=SQL, params=params, multirows=True)
|
524 |
if res:
|
525 |
data = [(i["source_name"], i["target_name"]) for i in res]
|
526 |
# print("Get node edge!",self.db.workspace, source_node_id,data)
|
|
|
528 |
else:
|
529 |
# print("Node Edge not exist!",self.db.workspace, source_node_id)
|
530 |
return []
|
531 |
+
|
532 |
+
async def get_all_nodes(self, limit: int):
|
533 |
+
"""查询所有节点"""
|
534 |
+
SQL = SQL_TEMPLATES["get_all_nodes"]
|
535 |
+
params = {"workspace":self.db.workspace, "limit":str(limit)}
|
536 |
+
res = await self.db.query(sql=SQL,params=params, multirows=True)
|
537 |
+
if res:
|
538 |
+
return res
|
539 |
|
540 |
+
async def get_all_edges(self, limit: int):
|
541 |
+
"""查询所有边"""
|
542 |
+
SQL = SQL_TEMPLATES["get_all_edges"]
|
543 |
+
params = {"workspace":self.db.workspace, "limit":str(limit)}
|
544 |
+
res = await self.db.query(sql=SQL,params=params, multirows=True)
|
545 |
+
if res:
|
546 |
+
return res
|
547 |
N_T = {
|
548 |
"full_docs": "LIGHTRAG_DOC_FULL",
|
549 |
"text_chunks": "LIGHTRAG_DOC_CHUNKS",
|
|
|
637 |
|
638 |
SQL_TEMPLATES = {
|
639 |
# SQL for KVStorage
|
640 |
+
"get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
|
641 |
+
"get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id",
|
642 |
+
"get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID in ({ids})",
|
643 |
+
"get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})",
|
644 |
+
"filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
|
645 |
"merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a
|
646 |
USING DUAL
|
647 |
+
ON (a.id = :check_id)
|
648 |
WHEN NOT MATCHED THEN
|
649 |
+
INSERT(id,content,workspace) values(:id,:content,:workspace)
|
650 |
""",
|
651 |
"merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a
|
652 |
USING DUAL
|
653 |
+
ON (a.id = :check_id)
|
654 |
WHEN NOT MATCHED THEN
|
655 |
INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector)
|
656 |
+
values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector) """,
|
657 |
# SQL for VectorStorage
|
658 |
"entities": """SELECT name as entity_name FROM
|
659 |
+
(SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
|
660 |
+
FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace)
|
661 |
+
WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
|
662 |
"relationships": """SELECT source_name as src_id, target_name as tgt_id FROM
|
663 |
+
(SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
|
664 |
+
FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace)
|
665 |
+
WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
|
666 |
"chunks": """SELECT id FROM
|
667 |
+
(SELECT id,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
|
668 |
+
FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace)
|
669 |
+
WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
|
670 |
# SQL for GraphStorage
|
671 |
"has_node": """SELECT * FROM GRAPH_TABLE (lightrag_graph
|
672 |
MATCH (a)
|
673 |
+
WHERE a.workspace=:workspace AND a.name=:node_id
|
674 |
COLUMNS (a.name))""",
|
675 |
"has_edge": """SELECT * FROM GRAPH_TABLE (lightrag_graph
|
676 |
MATCH (a) -[e]-> (b)
|
677 |
+
WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
|
678 |
+
AND a.name=:source_node_id AND b.name=:target_node_id
|
679 |
COLUMNS (e.source_name,e.target_name) )""",
|
680 |
"node_degree": """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph
|
681 |
MATCH (a)-[e]->(b)
|
682 |
+
WHERE a.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
|
683 |
+
AND a.name=:node_id or b.name = :node_id
|
684 |
COLUMNS (a.name))""",
|
685 |
"get_node": """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description
|
686 |
FROM GRAPH_TABLE (lightrag_graph
|
687 |
MATCH (a)
|
688 |
+
WHERE a.workspace=:workspace AND a.name=:node_id
|
689 |
COLUMNS (a.name)
|
690 |
) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name
|
691 |
+
WHERE t2.workspace=:workspace""",
|
692 |
"get_edge": """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords,
|
693 |
NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords
|
694 |
FROM GRAPH_TABLE (lightrag_graph
|
695 |
MATCH (a)-[e]->(b)
|
696 |
+
WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
|
697 |
+
AND a.name=:source_node_id and b.name = :target_node_id
|
698 |
COLUMNS (e.id,a.name as source_id)
|
699 |
) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""",
|
700 |
"get_node_edges": """SELECT source_name,target_name
|
701 |
FROM GRAPH_TABLE (lightrag_graph
|
702 |
MATCH (a)-[e]->(b)
|
703 |
+
WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
|
704 |
+
AND a.name=:source_node_id
|
705 |
COLUMNS (a.name as source_name,b.name as target_name))""",
|
706 |
"merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
|
707 |
USING DUAL
|
708 |
+
ON (a.workspace = :workspace and a.name=:name and a.source_chunk_id=:source_chunk_id)
|
709 |
WHEN NOT MATCHED THEN
|
710 |
INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector)
|
711 |
+
values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector) """,
|
712 |
"merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a
|
713 |
USING DUAL
|
714 |
+
ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and a.source_chunk_id=:source_chunk_id)
|
715 |
WHEN NOT MATCHED THEN
|
716 |
INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
|
717 |
+
values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """,
|
718 |
+
"get_all_nodes":"""SELECT t1.name as id,t1.entity_type as label,t1.DESCRIPTION,t2.content
|
719 |
+
FROM LIGHTRAG_GRAPH_NODES t1
|
720 |
+
LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
|
721 |
+
WHERE t1.workspace=:workspace
|
722 |
+
order by t1.CREATETIME DESC
|
723 |
+
fetch first :limit rows only
|
724 |
+
""",
|
725 |
+
"get_all_edges":"""SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target,
|
726 |
+
t1.weight,t1.DESCRIPTION,t2.content
|
727 |
+
FROM LIGHTRAG_GRAPH_EDGES t1
|
728 |
+
LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
|
729 |
+
WHERE t1.workspace=:workspace
|
730 |
+
order by t1.CREATETIME DESC
|
731 |
+
fetch first :limit rows only"""
|
732 |
}
|
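One subtlety in the templates above: bind variables can carry data values only, not pieces of the statement's type syntax, so the vector queries keep str.format for the vector({dimension},{dtype}) constructor while binding the vector text, workspace, threshold, and top_k. A sketch of how OracleVectorDBStorage.query drives the "chunks" template; the values are illustrative, not from the commit:

from lightrag.kg.oracle_impl import SQL_TEMPLATES

# Illustrative values; in LightRAG these come from the embedding call and QueryParam.
dimension, dtype, top_k = 3, "FLOAT32", 10   # toy dimension for brevity
SQL = SQL_TEMPLATES["chunks"].format(dimension=dimension, dtype=dtype)
params = {
    "embedding_string": "[" + ", ".join(map(str, [0.12, -0.03, 0.88])) + "]",
    "workspace": "default",
    "top_k": top_k,
    "better_than_threshold": 0.2,
}
print(SQL)  # the statement text is constant; only the binds vary per query
# results = await db.query(SQL, params=params, multirows=True)   # db: an OracleDB instance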
lightrag/operate.py
CHANGED

@@ -405,12 +405,13 @@ async def local_query(
     kw_prompt = kw_prompt_temp.format(query=query)
     result = await use_model_func(kw_prompt)
     json_text = locate_json_string_body_from_string(result)
-
+    logger.debug("local_query json_text:", json_text)
     try:
         keywords_data = json.loads(json_text)
         keywords = keywords_data.get("low_level_keywords", [])
         keywords = ", ".join(keywords)
     except json.JSONDecodeError:
+        print(result)
         try:
             result = (
                 result.replace(kw_prompt[:-1], "")
@@ -443,6 +444,8 @@ async def local_query(
     sys_prompt = sys_prompt_temp.format(
         context_data=context, response_type=query_param.response_type
     )
+    if query_param.only_need_prompt:
+        return sys_prompt
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
@@ -672,12 +675,12 @@ async def global_query(
     kw_prompt = kw_prompt_temp.format(query=query)
     result = await use_model_func(kw_prompt)
     json_text = locate_json_string_body_from_string(result)
-
+    logger.debug("global json_text:", json_text)
     try:
         keywords_data = json.loads(json_text)
         keywords = keywords_data.get("high_level_keywords", [])
         keywords = ", ".join(keywords)
-    except json.JSONDecodeError:
+    except json.JSONDecodeError:
         try:
             result = (
                 result.replace(kw_prompt[:-1], "")
@@ -714,6 +717,8 @@ async def global_query(
     sys_prompt = sys_prompt_temp.format(
         context_data=context, response_type=query_param.response_type
     )
+    if query_param.only_need_prompt:
+        return sys_prompt
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
@@ -914,6 +919,7 @@ async def hybrid_query(
 
     result = await use_model_func(kw_prompt)
     json_text = locate_json_string_body_from_string(result)
+    logger.debug("hybrid_query json_text:", json_text)
     try:
         keywords_data = json.loads(json_text)
         hl_keywords = keywords_data.get("high_level_keywords", [])
@@ -969,6 +975,8 @@ async def hybrid_query(
     sys_prompt = sys_prompt_temp.format(
         context_data=context, response_type=query_param.response_type
     )
+    if query_param.only_need_prompt:
+        return sys_prompt
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
@@ -1079,6 +1087,8 @@ async def naive_query(
     sys_prompt = sys_prompt_temp.format(
         content_data=section, response_type=query_param.response_type
     )
+    if query_param.only_need_prompt:
+        return sys_prompt
     response = await use_model_func(
         query,
         system_prompt=sys_prompt,
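A small caveat about the debug lines added above (an editorial observation, not part of the commit): the standard logging module formats lazily with %-style placeholders, so a bare second argument is only rendered when the message contains a matching %s. A minimal sketch of the difference:

import logging

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("lightrag")
json_text = '{"high_level_keywords": []}'

# As committed: the message has no %s, so logging raises and reports a
# string-formatting error internally instead of printing json_text.
logger.debug("local_query json_text:", json_text)
# Conventional %-style form that actually renders the value:
logger.debug("local_query json_text: %s", json_text)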
lightrag/utils.py
CHANGED

@@ -49,7 +49,11 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
     """Locate the JSON string body from a string"""
     maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
     if maybe_json_str is not None:
-        return maybe_json_str.group(0)
+        maybe_json_str = maybe_json_str.group(0)
+        maybe_json_str = maybe_json_str.replace("\\n", "")
+        maybe_json_str = maybe_json_str.replace("\n", "")
+        maybe_json_str = maybe_json_str.replace("'", '"')
+        return maybe_json_str
     else:
         return None
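The added replace calls make the helper tolerant of model output that is only almost JSON (single-quoted dicts, stray newlines), which is exactly what the keyword-extraction paths in operate.py choke on. A quick illustration; the wrapper text is invented for the example, and note the trade-off that a legitimate apostrophe inside a JSON string would now be corrupted:

import json
from lightrag.utils import locate_json_string_body_from_string

raw = "Sure, here you go:\n{'high_level_keywords': ['Oracle', 'bind variables'],\n 'low_level_keywords': []}"
json_text = locate_json_string_body_from_string(raw)
# Newlines are stripped and ' becomes ", so the result parses cleanly:
print(json.loads(json_text)["high_level_keywords"])  # ['Oracle', 'bind variables']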