yangdx commited on
Commit
18efa1c
·
1 Parent(s): 706f457

Fix linting

Browse files
Files changed (1) hide show
  1. lightrag/kg/neo4j_impl.py +56 -26
lightrag/kg/neo4j_impl.py CHANGED
@@ -195,14 +195,12 @@ class Neo4JStorage(BaseGraphStorage):
195
  ) as session:
196
  try:
197
  query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists"
198
- result = await session.run(query, entity_id = node_id)
199
  single_result = await result.single()
200
  await result.consume() # Ensure result is fully consumed
201
  return single_result["node_exists"]
202
  except Exception as e:
203
- logger.error(
204
- f"Error checking node existence for {node_id}: {str(e)}"
205
- )
206
  await result.consume() # Ensure results are consumed even on error
207
  raise
208
 
@@ -229,7 +227,11 @@ class Neo4JStorage(BaseGraphStorage):
229
  "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) "
230
  "RETURN COUNT(r) > 0 AS edgeExists"
231
  )
232
- result = await session.run(query, source_entity_id = source_node_id, target_entity_id = target_node_id)
 
 
 
 
233
  single_result = await result.single()
234
  await result.consume() # Ensure result is fully consumed
235
  return single_result["edgeExists"]
@@ -274,7 +276,11 @@ class Neo4JStorage(BaseGraphStorage):
274
  node_dict = dict(node)
275
  # Remove base label from labels list if it exists
276
  if "labels" in node_dict:
277
- node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
 
 
 
 
278
  logger.debug(f"Neo4j query node {query} return: {node_dict}")
279
  return node_dict
280
  return None
@@ -308,25 +314,23 @@ class Neo4JStorage(BaseGraphStorage):
308
  OPTIONAL MATCH (n)-[r]-()
309
  RETURN COUNT(r) AS degree
310
  """
311
- result = await session.run(query, entity_id = node_id)
312
  try:
313
  record = await result.single()
314
 
315
  if not record:
316
- logger.warning(
317
- f"No node found with label '{node_id}'"
318
- )
319
  return 0
320
 
321
  degree = record["degree"]
322
- logger.debug("Neo4j query node degree for {node_id} return: {degree}")
 
 
323
  return degree
324
  finally:
325
  await result.consume() # Ensure result is fully consumed
326
  except Exception as e:
327
- logger.error(
328
- f"Error getting node degree for {node_id}: {str(e)}"
329
- )
330
  raise
331
 
332
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
@@ -373,7 +377,11 @@ class Neo4JStorage(BaseGraphStorage):
373
  MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id})
374
  RETURN properties(r) as edge_properties
375
  """
376
- result = await session.run(query, source_entity_id=source_node_id, target_entity_id=target_node_id)
 
 
 
 
377
  try:
378
  records = await result.fetch(2)
379
 
@@ -471,10 +479,14 @@ class Neo4JStorage(BaseGraphStorage):
471
  continue
472
 
473
  source_label = (
474
- source_node.get("entity_id") if source_node.get("entity_id") else None
 
 
475
  )
476
  target_label = (
477
- connected_node.get("entity_id") if connected_node.get("entity_id") else None
 
 
478
  )
479
 
480
  if source_label and target_label:
@@ -483,7 +495,9 @@ class Neo4JStorage(BaseGraphStorage):
483
  await results.consume() # Ensure results are consumed
484
  return edges
485
  except Exception as e:
486
- logger.error(f"Error getting edges for node {source_node_id}: {str(e)}")
 
 
487
  await results.consume() # Ensure results are consumed even on error
488
  raise
489
  except Exception as e:
@@ -520,11 +534,14 @@ class Neo4JStorage(BaseGraphStorage):
520
  async with self._driver.session(database=self._DATABASE) as session:
521
 
522
  async def execute_upsert(tx: AsyncManagedTransaction):
523
- query = """
 
524
  MERGE (n:base {entity_id: $properties.entity_id})
525
  SET n += $properties
526
  SET n:`%s`
527
- """ % entity_type
 
 
528
  result = await tx.run(query, properties=properties)
529
  logger.debug(
530
  f"Upserted node with entity_id '{entity_id}' and properties: {properties}"
@@ -548,7 +565,6 @@ class Neo4JStorage(BaseGraphStorage):
548
  )
549
  ),
550
  )
551
-
552
  @retry(
553
  stop=stop_after_attempt(3),
554
  wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -728,7 +744,11 @@ class Neo4JStorage(BaseGraphStorage):
728
  result.nodes.append(
729
  KnowledgeGraphNode(
730
  id=f"{node_id}",
731
- labels=[label for label in node.labels if label != "base"],
 
 
 
 
732
  properties=dict(node),
733
  )
734
  )
@@ -767,7 +787,9 @@ class Neo4JStorage(BaseGraphStorage):
767
  logger.warning(
768
  "Neo4j: inclusive search mode is not supported in recursive query, using exact matching"
769
  )
770
- return await self._robust_fallback(node_label, max_depth, min_degree)
 
 
771
 
772
  return result
773
 
@@ -843,7 +865,9 @@ class Neo4JStorage(BaseGraphStorage):
843
  # Create KnowledgeGraphNode for target
844
  target_node = KnowledgeGraphNode(
845
  id=f"{target_id}",
846
- labels=[label for label in b_node.labels if label != "base"],
 
 
847
  properties=dict(b_node.properties),
848
  )
849
 
@@ -883,7 +907,9 @@ class Neo4JStorage(BaseGraphStorage):
883
  # Create initial KnowledgeGraphNode
884
  start_node = KnowledgeGraphNode(
885
  id=f"{node_record['n'].get('entity_id')}",
886
- labels=[label for label in node_record["n"].labels if label != "base"],
 
 
887
  properties=dict(node_record["n"].properties),
888
  )
889
  finally:
@@ -942,6 +968,7 @@ class Neo4JStorage(BaseGraphStorage):
942
  Args:
943
  node_id: The label of the node to delete
944
  """
 
945
  async def _do_delete(tx: AsyncManagedTransaction):
946
  query = """
947
  MATCH (n:base {entity_id: $entity_id})
@@ -998,12 +1025,15 @@ class Neo4JStorage(BaseGraphStorage):
998
  edges: List of edges to be deleted, each edge is a (source, target) tuple
999
  """
1000
  for source, target in edges:
 
1001
  async def _do_delete_edge(tx: AsyncManagedTransaction):
1002
  query = """
1003
  MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id})
1004
  DELETE r
1005
  """
1006
- result = await tx.run(query, source_entity_id=source, target_entity_id=target)
 
 
1007
  logger.debug(f"Deleted edge from '{source}' to '{target}'")
1008
  await result.consume() # Ensure result is fully consumed
1009
 
 
195
  ) as session:
196
  try:
197
  query = "MATCH (n:base {entity_id: $entity_id}) RETURN count(n) > 0 AS node_exists"
198
+ result = await session.run(query, entity_id=node_id)
199
  single_result = await result.single()
200
  await result.consume() # Ensure result is fully consumed
201
  return single_result["node_exists"]
202
  except Exception as e:
203
+ logger.error(f"Error checking node existence for {node_id}: {str(e)}")
 
 
204
  await result.consume() # Ensure results are consumed even on error
205
  raise
206
 
 
227
  "MATCH (a:base {entity_id: $source_entity_id})-[r]-(b:base {entity_id: $target_entity_id}) "
228
  "RETURN COUNT(r) > 0 AS edgeExists"
229
  )
230
+ result = await session.run(
231
+ query,
232
+ source_entity_id=source_node_id,
233
+ target_entity_id=target_node_id,
234
+ )
235
  single_result = await result.single()
236
  await result.consume() # Ensure result is fully consumed
237
  return single_result["edgeExists"]
 
276
  node_dict = dict(node)
277
  # Remove base label from labels list if it exists
278
  if "labels" in node_dict:
279
+ node_dict["labels"] = [
280
+ label
281
+ for label in node_dict["labels"]
282
+ if label != "base"
283
+ ]
284
  logger.debug(f"Neo4j query node {query} return: {node_dict}")
285
  return node_dict
286
  return None
 
314
  OPTIONAL MATCH (n)-[r]-()
315
  RETURN COUNT(r) AS degree
316
  """
317
+ result = await session.run(query, entity_id=node_id)
318
  try:
319
  record = await result.single()
320
 
321
  if not record:
322
+ logger.warning(f"No node found with label '{node_id}'")
 
 
323
  return 0
324
 
325
  degree = record["degree"]
326
+ logger.debug(
327
+ "Neo4j query node degree for {node_id} return: {degree}"
328
+ )
329
  return degree
330
  finally:
331
  await result.consume() # Ensure result is fully consumed
332
  except Exception as e:
333
+ logger.error(f"Error getting node degree for {node_id}: {str(e)}")
 
 
334
  raise
335
 
336
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
 
377
  MATCH (start:base {entity_id: $source_entity_id})-[r]-(end:base {entity_id: $target_entity_id})
378
  RETURN properties(r) as edge_properties
379
  """
380
+ result = await session.run(
381
+ query,
382
+ source_entity_id=source_node_id,
383
+ target_entity_id=target_node_id,
384
+ )
385
  try:
386
  records = await result.fetch(2)
387
 
 
479
  continue
480
 
481
  source_label = (
482
+ source_node.get("entity_id")
483
+ if source_node.get("entity_id")
484
+ else None
485
  )
486
  target_label = (
487
+ connected_node.get("entity_id")
488
+ if connected_node.get("entity_id")
489
+ else None
490
  )
491
 
492
  if source_label and target_label:
 
495
  await results.consume() # Ensure results are consumed
496
  return edges
497
  except Exception as e:
498
+ logger.error(
499
+ f"Error getting edges for node {source_node_id}: {str(e)}"
500
+ )
501
  await results.consume() # Ensure results are consumed even on error
502
  raise
503
  except Exception as e:
 
534
  async with self._driver.session(database=self._DATABASE) as session:
535
 
536
  async def execute_upsert(tx: AsyncManagedTransaction):
537
+ query = (
538
+ """
539
  MERGE (n:base {entity_id: $properties.entity_id})
540
  SET n += $properties
541
  SET n:`%s`
542
+ """
543
+ % entity_type
544
+ )
545
  result = await tx.run(query, properties=properties)
546
  logger.debug(
547
  f"Upserted node with entity_id '{entity_id}' and properties: {properties}"
 
565
  )
566
  ),
567
  )
 
568
  @retry(
569
  stop=stop_after_attempt(3),
570
  wait=wait_exponential(multiplier=1, min=4, max=10),
 
744
  result.nodes.append(
745
  KnowledgeGraphNode(
746
  id=f"{node_id}",
747
+ labels=[
748
+ label
749
+ for label in node.labels
750
+ if label != "base"
751
+ ],
752
  properties=dict(node),
753
  )
754
  )
 
787
  logger.warning(
788
  "Neo4j: inclusive search mode is not supported in recursive query, using exact matching"
789
  )
790
+ return await self._robust_fallback(
791
+ node_label, max_depth, min_degree
792
+ )
793
 
794
  return result
795
 
 
865
  # Create KnowledgeGraphNode for target
866
  target_node = KnowledgeGraphNode(
867
  id=f"{target_id}",
868
+ labels=[
869
+ label for label in b_node.labels if label != "base"
870
+ ],
871
  properties=dict(b_node.properties),
872
  )
873
 
 
907
  # Create initial KnowledgeGraphNode
908
  start_node = KnowledgeGraphNode(
909
  id=f"{node_record['n'].get('entity_id')}",
910
+ labels=[
911
+ label for label in node_record["n"].labels if label != "base"
912
+ ],
913
  properties=dict(node_record["n"].properties),
914
  )
915
  finally:
 
968
  Args:
969
  node_id: The label of the node to delete
970
  """
971
+
972
  async def _do_delete(tx: AsyncManagedTransaction):
973
  query = """
974
  MATCH (n:base {entity_id: $entity_id})
 
1025
  edges: List of edges to be deleted, each edge is a (source, target) tuple
1026
  """
1027
  for source, target in edges:
1028
+
1029
  async def _do_delete_edge(tx: AsyncManagedTransaction):
1030
  query = """
1031
  MATCH (source:base {entity_id: $source_entity_id})-[r]-(target:base {entity_id: $target_entity_id})
1032
  DELETE r
1033
  """
1034
+ result = await tx.run(
1035
+ query, source_entity_id=source, target_entity_id=target
1036
+ )
1037
  logger.debug(f"Deleted edge from '{source}' to '{target}'")
1038
  await result.consume() # Ensure result is fully consumed
1039