jin committed on
Commit c77f948 · 1 Parent(s): 5dcb28f

use oracle bind variables to avoid error

lightrag/base.py CHANGED
@@ -17,6 +17,7 @@ T = TypeVar("T")
 class QueryParam:
     mode: Literal["local", "global", "hybrid", "naive"] = "global"
     only_need_context: bool = False
+    only_need_prompt: bool = False
     response_type: str = "Multiple Paragraphs"
     # Number of top-k items to retrieve; corresponds to entities in "local" mode and relationships in "global" mode.
     top_k: int = 60
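
The new flag mirrors `only_need_context`: when set, the query pipeline in `lightrag/operate.py` (changed below) returns the fully rendered system prompt instead of calling the LLM. A minimal usage sketch, assuming the usual `LightRAG` entry point; the working directory and question are illustrative:

    from lightrag import LightRAG, QueryParam

    rag = LightRAG(working_dir="./rag_storage")  # illustrative setup

    # Returns the assembled system prompt (context already injected),
    # skipping the model call entirely -- handy for prompt debugging.
    prompt = rag.query(
        "What are the main themes?",
        param=QueryParam(mode="hybrid", only_need_prompt=True),
    )
    print(prompt)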
lightrag/kg/oracle_impl.py CHANGED
@@ -114,16 +114,17 @@ class OracleDB:
 
         logger.info("Finished check all tables in Oracle database")
 
-    async def query(self, sql: str, multirows: bool = False) -> Union[dict, None]:
+    async def query(self, sql: str, params: dict = None, multirows: bool = False) -> Union[dict, None]:
         async with self.pool.acquire() as connection:
             connection.inputtypehandler = self.input_type_handler
             connection.outputtypehandler = self.output_type_handler
             with connection.cursor() as cursor:
                 try:
-                    await cursor.execute(sql)
+                    await cursor.execute(sql, params)
                 except Exception as e:
                     logger.error(f"Oracle database error: {e}")
                     print(sql)
+                    print(params)
                     raise
                 columns = [column[0].lower() for column in cursor.description]
                 if multirows:
@@ -140,7 +141,7 @@ class OracleDB:
             data = None
         return data
 
-    async def execute(self, sql: str, data: list = None):
+    async def execute(self, sql: str, data: list | dict = None):
         # logger.info("go into OracleDB execute method")
         try:
             async with self.pool.acquire() as connection:
@@ -172,11 +173,10 @@ class OracleKVStorage(BaseKVStorage):
 
     async def get_by_id(self, id: str) -> Union[dict, None]:
         """Fetch doc_full data by id."""
-        SQL = SQL_TEMPLATES["get_by_id_" + self.namespace].format(
-            workspace=self.db.workspace, id=id
-        )
+        SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
+        params = {"workspace": self.db.workspace, "id": id}
         # print("get_by_id:"+SQL)
-        res = await self.db.query(SQL)
+        res = await self.db.query(SQL, params)
         if res:
             data = res  # {"data":res}
             # print (data)
@@ -187,11 +187,11 @@ class OracleKVStorage(BaseKVStorage):
     # Query by id
     async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
         """Fetch doc_chunks data by ids."""
-        SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
-            workspace=self.db.workspace, ids=",".join([f"'{id}'" for id in ids])
-        )
-        # print("get_by_ids:"+SQL)
-        res = await self.db.query(SQL, multirows=True)
+        SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(ids=",".join([f"'{id}'" for id in ids]))
+        params = {"workspace": self.db.workspace}
+        # print("get_by_ids:"+SQL)
+        # print(params)
+        res = await self.db.query(SQL, params, multirows=True)
         if res:
             data = res  # [{"data":i} for i in res]
             # print(data)
@@ -201,12 +201,16 @@ class OracleKVStorage(BaseKVStorage):
 
     async def filter_keys(self, keys: list[str]) -> set[str]:
         """Filter out duplicate content."""
-        SQL = SQL_TEMPLATES["filter_keys"].format(
-            table_name=N_T[self.namespace],
-            workspace=self.db.workspace,
-            ids=",".join([f"'{k}'" for k in keys]),
-        )
-        res = await self.db.query(SQL, multirows=True)
+        SQL = SQL_TEMPLATES["filter_keys"].format(
+            table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys])
+        )
+        params = {"workspace": self.db.workspace}
+        try:
+            await self.db.query(SQL, params)
+        except Exception as e:
+            logger.error(f"Oracle database error: {e}")
+            print(SQL)
+            print(params)
+        res = await self.db.query(SQL, params, multirows=True)
         data = None
         if res:
             exist_keys = [key["id"] for key in res]
@@ -243,29 +247,31 @@ class OracleKVStorage(BaseKVStorage):
                 d["__vector__"] = embeddings[i]
             # print(list_data)
             for item in list_data:
-                merge_sql = SQL_TEMPLATES["merge_chunk"].format(check_id=item["__id__"])
-
-                values = [
-                    item["__id__"],
-                    item["content"],
-                    self.db.workspace,
-                    item["tokens"],
-                    item["chunk_order_index"],
-                    item["full_doc_id"],
-                    item["__vector__"],
-                ]
+                merge_sql = SQL_TEMPLATES["merge_chunk"]
+                data = {
+                    "check_id": item["__id__"],
+                    "id": item["__id__"],
+                    "content": item["content"],
+                    "workspace": self.db.workspace,
+                    "tokens": item["tokens"],
+                    "chunk_order_index": item["chunk_order_index"],
+                    "full_doc_id": item["full_doc_id"],
+                    "content_vector": item["__vector__"]
+                }
                 # print(merge_sql)
-                await self.db.execute(merge_sql, values)
+                await self.db.execute(merge_sql, data)
 
         if self.namespace == "full_docs":
             for k, v in self._data.items():
                 # values.clear()
-                merge_sql = SQL_TEMPLATES["merge_doc_full"].format(
-                    check_id=k,
-                )
-                values = [k, self._data[k]["content"], self.db.workspace]
+                merge_sql = SQL_TEMPLATES["merge_doc_full"]
+                data = {
+                    "check_id": k,
+                    "id": k,
+                    "content": v["content"],
+                    "workspace": self.db.workspace
+                }
                 # print(merge_sql)
-                await self.db.execute(merge_sql, values)
+                await self.db.execute(merge_sql, data)
         return left_data
 
     async def index_done_callback(self):
@@ -295,18 +301,17 @@ class OracleVectorDBStorage(BaseVectorStorage):
        # Convert precision
         dtype = str(embedding.dtype).upper()
         dimension = embedding.shape[0]
-        embedding_string = ", ".join(map(str, embedding.tolist()))
-
-        SQL = SQL_TEMPLATES[self.namespace].format(
-            embedding_string=embedding_string,
-            dimension=dimension,
-            dtype=dtype,
-            workspace=self.db.workspace,
-            top_k=top_k,
-            better_than_threshold=self.cosine_better_than_threshold,
-        )
+        embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"
+
+        SQL = SQL_TEMPLATES[self.namespace].format(dimension=dimension, dtype=dtype)
+        params = {
+            "embedding_string": embedding_string,
+            "workspace": self.db.workspace,
+            "top_k": top_k,
+            "better_than_threshold": self.cosine_better_than_threshold,
+        }
         # print(SQL)
-        results = await self.db.query(SQL, multirows=True)
+        results = await self.db.query(SQL, params=params, multirows=True)
         # print("vector search result:",results)
         return results
 
@@ -339,22 +344,18 @@ class OracleGraphStorage(BaseGraphStorage):
         )
         embeddings = np.concatenate(embeddings_list)
         content_vector = embeddings[0]
-        merge_sql = SQL_TEMPLATES["merge_node"].format(
-            workspace=self.db.workspace, name=entity_name, source_chunk_id=source_id
-        )
+        merge_sql = SQL_TEMPLATES["merge_node"]
+        data = {
+            "workspace": self.db.workspace,
+            "name": entity_name,
+            "entity_type": entity_type,
+            "description": description,
+            "source_chunk_id": source_id,
+            "content": content,
+            "content_vector": content_vector
+        }
         # print(merge_sql)
-        await self.db.execute(
-            merge_sql,
-            [
-                self.db.workspace,
-                entity_name,
-                entity_type,
-                description,
-                source_id,
-                content,
-                content_vector,
-            ],
-        )
+        await self.db.execute(merge_sql, data)
         # self._graph.add_node(node_id, **node_data)
 
     async def upsert_edge(
@@ -379,27 +380,20 @@ class OracleGraphStorage(BaseGraphStorage):
         )
         embeddings = np.concatenate(embeddings_list)
         content_vector = embeddings[0]
-        merge_sql = SQL_TEMPLATES["merge_edge"].format(
-            workspace=self.db.workspace,
-            source_name=source_name,
-            target_name=target_name,
-            source_chunk_id=source_chunk_id,
-        )
+        merge_sql = SQL_TEMPLATES["merge_edge"]
+        data = {
+            "workspace": self.db.workspace,
+            "source_name": source_name,
+            "target_name": target_name,
+            "weight": weight,
+            "keywords": keywords,
+            "description": description,
+            "source_chunk_id": source_chunk_id,
+            "content": content,
+            "content_vector": content_vector
+        }
         # print(merge_sql)
-        await self.db.execute(
-            merge_sql,
-            [
-                self.db.workspace,
-                source_name,
-                target_name,
-                weight,
-                keywords,
-                description,
-                source_chunk_id,
-                content,
-                content_vector,
-            ],
-        )
+        await self.db.execute(merge_sql, data)
         # self._graph.add_edge(source_node_id, target_node_id, **edge_data)
 
     async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
@@ -429,12 +423,14 @@ class OracleGraphStorage(BaseGraphStorage):
     #################### query method #################
     async def has_node(self, node_id: str) -> bool:
         """Check whether a node exists by node id."""
-        SQL = SQL_TEMPLATES["has_node"].format(
-            workspace=self.db.workspace, node_id=node_id
-        )
+        SQL = SQL_TEMPLATES["has_node"]
+        params = {
+            "workspace": self.db.workspace,
+            "node_id": node_id
+        }
         # print(SQL)
         # print(self.db.workspace, node_id)
-        res = await self.db.query(SQL)
+        res = await self.db.query(SQL, params)
         if res:
             # print("Node exist!",res)
             return True
@@ -444,13 +440,14 @@ class OracleGraphStorage(BaseGraphStorage):
 
     async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
         """Check whether an edge exists by source and target node ids."""
-        SQL = SQL_TEMPLATES["has_edge"].format(
-            workspace=self.db.workspace,
-            source_node_id=source_node_id,
-            target_node_id=target_node_id,
-        )
+        SQL = SQL_TEMPLATES["has_edge"]
+        params = {
+            "workspace": self.db.workspace,
+            "source_node_id": source_node_id,
+            "target_node_id": target_node_id
+        }
         # print(SQL)
-        res = await self.db.query(SQL)
+        res = await self.db.query(SQL, params)
         if res:
             # print("Edge exist!",res)
             return True
@@ -460,11 +457,13 @@ class OracleGraphStorage(BaseGraphStorage):
 
     async def node_degree(self, node_id: str) -> int:
         """Get a node's degree by node id."""
-        SQL = SQL_TEMPLATES["node_degree"].format(
-            workspace=self.db.workspace, node_id=node_id
-        )
+        SQL = SQL_TEMPLATES["node_degree"]
+        params = {
+            "workspace": self.db.workspace,
+            "node_id": node_id
+        }
         # print(SQL)
-        res = await self.db.query(SQL)
+        res = await self.db.query(SQL, params)
         if res:
             # print("Node degree",res["degree"])
             return res["degree"]
@@ -480,12 +479,14 @@ class OracleGraphStorage(BaseGraphStorage):
 
     async def get_node(self, node_id: str) -> Union[dict, None]:
         """Get node data by node id."""
-        SQL = SQL_TEMPLATES["get_node"].format(
-            workspace=self.db.workspace, node_id=node_id
-        )
+        SQL = SQL_TEMPLATES["get_node"]
+        params = {
+            "workspace": self.db.workspace,
+            "node_id": node_id
+        }
         # print(self.db.workspace, node_id)
         # print(SQL)
-        res = await self.db.query(SQL)
+        res = await self.db.query(SQL, params)
         if res:
             # print("Get node!",self.db.workspace, node_id,res)
             return res
@@ -497,12 +498,13 @@ class OracleGraphStorage(BaseGraphStorage):
         self, source_node_id: str, target_node_id: str
     ) -> Union[dict, None]:
         """Get an edge by source and target node ids."""
-        SQL = SQL_TEMPLATES["get_edge"].format(
-            workspace=self.db.workspace,
-            source_node_id=source_node_id,
-            target_node_id=target_node_id,
-        )
-        res = await self.db.query(SQL)
+        SQL = SQL_TEMPLATES["get_edge"]
+        params = {
+            "workspace": self.db.workspace,
+            "source_node_id": source_node_id,
+            "target_node_id": target_node_id
+        }
+        res = await self.db.query(SQL, params)
         if res:
             # print("Get edge!",self.db.workspace, source_node_id, target_node_id,res[0])
             return res
@@ -513,10 +515,12 @@ class OracleGraphStorage(BaseGraphStorage):
     async def get_node_edges(self, source_node_id: str):
         """Get all edges of a node by node id."""
         if await self.has_node(source_node_id):
-            SQL = SQL_TEMPLATES["get_node_edges"].format(
-                workspace=self.db.workspace, source_node_id=source_node_id
-            )
-            res = await self.db.query(sql=SQL, multirows=True)
+            SQL = SQL_TEMPLATES["get_node_edges"]
+            params = {
+                "workspace": self.db.workspace,
+                "source_node_id": source_node_id
+            }
+            res = await self.db.query(sql=SQL, params=params, multirows=True)
             if res:
                 data = [(i["source_name"], i["target_name"]) for i in res]
                 # print("Get node edge!",self.db.workspace, source_node_id,data)
@@ -524,8 +528,22 @@ class OracleGraphStorage(BaseGraphStorage):
         else:
             # print("Node Edge not exist!",self.db.workspace, source_node_id)
             return []
 
-
+    async def get_all_nodes(self, limit: int):
+        """Query all nodes."""
+        SQL = SQL_TEMPLATES["get_all_nodes"]
+        params = {"workspace": self.db.workspace, "limit": str(limit)}
+        res = await self.db.query(sql=SQL, params=params, multirows=True)
+        if res:
+            return res
+
+    async def get_all_edges(self, limit: int):
+        """Query all edges."""
+        SQL = SQL_TEMPLATES["get_all_edges"]
+        params = {"workspace": self.db.workspace, "limit": str(limit)}
+        res = await self.db.query(sql=SQL, params=params, multirows=True)
+        if res:
+            return res
 N_T = {
     "full_docs": "LIGHTRAG_DOC_FULL",
     "text_chunks": "LIGHTRAG_DOC_CHUNKS",
@@ -619,82 +637,96 @@ TABLES = {
 
 SQL_TEMPLATES = {
     # SQL for KVStorage
-    "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID='{id}'",
-    "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID='{id}'",
-    "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace='{workspace}' and ID in ({ids})",
-    "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace='{workspace}' and ID in ({ids})",
-    "filter_keys": "select id from {table_name} where workspace='{workspace}' and id in ({ids})",
+    "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
+    "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id",
+    "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID in ({ids})",
+    "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})",
+    "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
     "merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a
     USING DUAL
-    ON (a.id = '{check_id}')
+    ON (a.id = :check_id)
     WHEN NOT MATCHED THEN
-    INSERT(id,content,workspace) values(:1,:2,:3)
+    INSERT(id,content,workspace) values(:id,:content,:workspace)
     """,
     "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a
     USING DUAL
-    ON (a.id = '{check_id}')
+    ON (a.id = :check_id)
     WHEN NOT MATCHED THEN
     INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector)
-    values (:1,:2,:3,:4,:5,:6,:7) """,
+    values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector) """,
     # SQL for VectorStorage
     "entities": """SELECT name as entity_name FROM
-    (SELECT id,name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
-    FROM LIGHTRAG_GRAPH_NODES WHERE workspace='{workspace}')
-    WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
+    (SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
+    FROM LIGHTRAG_GRAPH_NODES WHERE workspace=:workspace)
+    WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
    "relationships": """SELECT source_name as src_id, target_name as tgt_id FROM
-    (SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
-    FROM LIGHTRAG_GRAPH_EDGES WHERE workspace='{workspace}')
-    WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
+    (SELECT id,source_name,target_name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
+    FROM LIGHTRAG_GRAPH_EDGES WHERE workspace=:workspace)
+    WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
    "chunks": """SELECT id FROM
-    (SELECT id,VECTOR_DISTANCE(content_vector,vector('[{embedding_string}]',{dimension},{dtype}),COSINE) as distance
-    FROM LIGHTRAG_DOC_CHUNKS WHERE workspace='{workspace}')
-    WHERE distance>{better_than_threshold} ORDER BY distance ASC FETCH FIRST {top_k} ROWS ONLY""",
+    (SELECT id,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
+    FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=:workspace)
+    WHERE distance>:better_than_threshold ORDER BY distance ASC FETCH FIRST :top_k ROWS ONLY""",
    # SQL for GraphStorage
    "has_node": """SELECT * FROM GRAPH_TABLE (lightrag_graph
    MATCH (a)
-    WHERE a.workspace='{workspace}' AND a.name='{node_id}'
+    WHERE a.workspace=:workspace AND a.name=:node_id
    COLUMNS (a.name))""",
    "has_edge": """SELECT * FROM GRAPH_TABLE (lightrag_graph
    MATCH (a) -[e]-> (b)
-    WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
-    AND a.name='{source_node_id}' AND b.name='{target_node_id}'
+    WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
+    AND a.name=:source_node_id AND b.name=:target_node_id
    COLUMNS (e.source_name,e.target_name) )""",
    "node_degree": """SELECT count(1) as degree FROM GRAPH_TABLE (lightrag_graph
    MATCH (a)-[e]->(b)
-    WHERE a.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
-    AND a.name='{node_id}' or b.name = '{node_id}'
+    WHERE a.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
+    AND a.name=:node_id or b.name = :node_id
    COLUMNS (a.name))""",
    "get_node": """SELECT t1.name,t2.entity_type,t2.source_chunk_id as source_id,NVL(t2.description,'') AS description
    FROM GRAPH_TABLE (lightrag_graph
    MATCH (a)
-    WHERE a.workspace='{workspace}' AND a.name='{node_id}'
+    WHERE a.workspace=:workspace AND a.name=:node_id
    COLUMNS (a.name)
    ) t1 JOIN LIGHTRAG_GRAPH_NODES t2 on t1.name=t2.name
-    WHERE t2.workspace='{workspace}'""",
+    WHERE t2.workspace=:workspace""",
    "get_edge": """SELECT t1.source_id,t2.weight,t2.source_chunk_id as source_id,t2.keywords,
    NVL(t2.description,'') AS description,NVL(t2.KEYWORDS,'') AS keywords
    FROM GRAPH_TABLE (lightrag_graph
    MATCH (a)-[e]->(b)
-    WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
-    AND a.name='{source_node_id}' and b.name = '{target_node_id}'
+    WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
+    AND a.name=:source_node_id and b.name = :target_node_id
    COLUMNS (e.id,a.name as source_id)
    ) t1 JOIN LIGHTRAG_GRAPH_EDGES t2 on t1.id=t2.id""",
    "get_node_edges": """SELECT source_name,target_name
    FROM GRAPH_TABLE (lightrag_graph
    MATCH (a)-[e]->(b)
-    WHERE e.workspace='{workspace}' and a.workspace='{workspace}' and b.workspace='{workspace}'
-    AND a.name='{source_node_id}'
+    WHERE e.workspace=:workspace and a.workspace=:workspace and b.workspace=:workspace
+    AND a.name=:source_node_id
    COLUMNS (a.name as source_name,b.name as target_name))""",
    "merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
    USING DUAL
-    ON (a.workspace = '{workspace}' and a.name='{name}' and a.source_chunk_id='{source_chunk_id}')
+    ON (a.workspace = :workspace and a.name=:name and a.source_chunk_id=:source_chunk_id)
    WHEN NOT MATCHED THEN
    INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector)
-    values (:1,:2,:3,:4,:5,:6,:7) """,
    "merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a
+    values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector) """,
+    "merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a
    USING DUAL
-    ON (a.workspace = '{workspace}' and a.source_name='{source_name}' and a.target_name='{target_name}' and a.source_chunk_id='{source_chunk_id}')
+    ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and a.source_chunk_id=:source_chunk_id)
    WHEN NOT MATCHED THEN
    INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
-    values (:1,:2,:3,:4,:5,:6,:7,:8,:9) """,
+    values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """,
+    "get_all_nodes": """SELECT t1.name as id,t1.entity_type as label,t1.DESCRIPTION,t2.content
+    FROM LIGHTRAG_GRAPH_NODES t1
+    LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
+    WHERE t1.workspace=:workspace
+    order by t1.CREATETIME DESC
+    fetch first :limit rows only
+    """,
+    "get_all_edges": """SELECT t1.id,t1.keywords as label,t1.keywords, t1.source_name as source, t1.target_name as target,
+    t1.weight,t1.DESCRIPTION,t2.content
+    FROM LIGHTRAG_GRAPH_EDGES t1
+    LEFT JOIN LIGHTRAG_DOC_CHUNKS t2 on t1.source_chunk_id=t2.id
+    WHERE t1.workspace=:workspace
+    order by t1.CREATETIME DESC
+    fetch first :limit rows only"""
 }
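
The core change above is the commit title in action: `str.format` interpolation is replaced by named bind variables, so python-oracledb ships the values separately from the SQL text. Quotes or other special characters in the data (e.g. an apostrophe in an extracted entity description) can then no longer break the statement, and the statement text stays stable for Oracle's statement cache. A minimal standalone sketch of the same pattern, using the synchronous python-oracledb API with placeholder connection details:

    import oracledb

    # Connection details are placeholders for illustration only.
    conn = oracledb.connect(user="user", password="pw", dsn="localhost/orclpdb1")
    with conn.cursor() as cursor:
        # Bind values travel separately from the SQL text, so a value like
        # "O'Brien" cannot terminate a string literal early -- the class of
        # Oracle syntax error that interpolated SQL is prone to.
        cursor.execute(
            "select ID from LIGHTRAG_DOC_FULL where workspace = :workspace and ID = :id",
            {"workspace": "default", "id": "doc-1"},
        )
        rows = cursor.fetchall()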
 
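One thing the commit deliberately leaves on `str.format` is the `{ids}` IN-lists in `get_by_ids_*` and `filter_keys`: Oracle cannot bind a Python list to a single `IN (...)` placeholder, so the id list is still spliced into the SQL text. If those ids ever came from untrusted input, a common workaround is to generate one numbered bind per element; a sketch with a hypothetical helper, not part of this commit:

    def build_in_clause(ids: list[str]) -> tuple[str, dict]:
        """Build ':id0,:id1,...' plus the matching bind dictionary."""
        placeholders = ",".join(f":id{i}" for i in range(len(ids)))
        binds = {f"id{i}": v for i, v in enumerate(ids)}
        return placeholders, binds

    placeholders, binds = build_in_clause(["chunk-1", "chunk-2"])
    sql = f"select id from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and id in ({placeholders})"
    binds["workspace"] = "default"
    # cursor.execute(sql, binds)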
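
The vector templates follow the same split: the embedding now travels as a single bind (`:embedding_string`, a bracketed `[x, y, ...]` literal, hence the new `"[" + ... + "]"` wrapping) into Oracle's `vector()` constructor, while `{dimension}` and `{dtype}` remain `str.format` fields, presumably because type descriptors cannot be bound. A sketch of the resulting call, with illustrative values:

    import numpy as np

    embedding = np.random.rand(4).astype(np.float32)  # illustrative embedding
    embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"

    sql = SQL_TEMPLATES["chunks"].format(dimension=embedding.shape[0], dtype="FLOAT32")
    params = {
        "embedding_string": embedding_string,
        "workspace": "default",
        "top_k": 5,
        "better_than_threshold": 0.2,
    }
    # results = await db.query(sql, params=params, multirows=True)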
lightrag/operate.py CHANGED
@@ -405,12 +405,13 @@ async def local_query(
    kw_prompt = kw_prompt_temp.format(query=query)
    result = await use_model_func(kw_prompt)
    json_text = locate_json_string_body_from_string(result)
-
+    logger.debug("local_query json_text:", json_text)
    try:
        keywords_data = json.loads(json_text)
        keywords = keywords_data.get("low_level_keywords", [])
        keywords = ", ".join(keywords)
    except json.JSONDecodeError:
+        print(result)
        try:
            result = (
                result.replace(kw_prompt[:-1], "")
@@ -443,6 +444,8 @@ async def local_query(
    sys_prompt = sys_prompt_temp.format(
        context_data=context, response_type=query_param.response_type
    )
+    if query_param.only_need_prompt:
+        return sys_prompt
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
@@ -672,12 +675,12 @@ async def global_query(
    kw_prompt = kw_prompt_temp.format(query=query)
    result = await use_model_func(kw_prompt)
    json_text = locate_json_string_body_from_string(result)
-
+    logger.debug("global json_text:", json_text)
    try:
        keywords_data = json.loads(json_text)
        keywords = keywords_data.get("high_level_keywords", [])
        keywords = ", ".join(keywords)
    except json.JSONDecodeError:
        try:
            result = (
                result.replace(kw_prompt[:-1], "")
@@ -714,6 +717,8 @@ async def global_query(
    sys_prompt = sys_prompt_temp.format(
        context_data=context, response_type=query_param.response_type
    )
+    if query_param.only_need_prompt:
+        return sys_prompt
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
@@ -914,6 +919,7 @@ async def hybrid_query(
 
    result = await use_model_func(kw_prompt)
    json_text = locate_json_string_body_from_string(result)
+    logger.debug("hybrid_query json_text:", json_text)
    try:
        keywords_data = json.loads(json_text)
        hl_keywords = keywords_data.get("high_level_keywords", [])
@@ -969,6 +975,8 @@ async def hybrid_query(
    sys_prompt = sys_prompt_temp.format(
        context_data=context, response_type=query_param.response_type
    )
+    if query_param.only_need_prompt:
+        return sys_prompt
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
@@ -1079,6 +1087,8 @@ async def naive_query(
    sys_prompt = sys_prompt_temp.format(
        content_data=section, response_type=query_param.response_type
    )
+    if query_param.only_need_prompt:
+        return sys_prompt
    response = await use_model_func(
        query,
        system_prompt=sys_prompt,
 
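A caveat on the added debug lines: the stdlib `logging` module uses lazy %-style formatting, so `logger.debug("local_query json_text:", json_text)` passes `json_text` as a format argument for a message with no placeholder, and the logging machinery reports a formatting error instead of printing the value. The equivalent placeholder form (same lazy behavior, assuming the module-level `logger`) would be:

    # %-style placeholder: the string is only formatted if DEBUG is enabled.
    logger.debug("local_query json_text: %s", json_text)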
lightrag/utils.py CHANGED
@@ -49,7 +49,11 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
    """Locate the JSON string body from a string"""
    maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
    if maybe_json_str is not None:
-        return maybe_json_str.group(0)
+        maybe_json_str = maybe_json_str.group(0)
+        maybe_json_str = maybe_json_str.replace("\\n", "")
+        maybe_json_str = maybe_json_str.replace("\n", "")
+        maybe_json_str = maybe_json_str.replace("'", '"')
+        return maybe_json_str
    else:
        return None
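
This cleanup makes single-quoted, pseudo-JSON model output parseable by `json.loads`, which accepts only double quotes; note that the blanket `'` to `"` replacement will also rewrite apostrophes inside values, so it is a pragmatic fix rather than a general one. A quick illustration of what the function now tolerates:

    import json

    raw = "Here you go:\n{'high_level_keywords': ['graph RAG'],\n 'low_level_keywords': ['Oracle']}"
    json_text = locate_json_string_body_from_string(raw)
    # -> '{"high_level_keywords": ["graph RAG"], "low_level_keywords": ["Oracle"]}'
    keywords = json.loads(json_text)["high_level_keywords"]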