yangdx committed on
Commit
83b5082
·
1 Parent(s): 16e7647

Fix linting

Browse files
lightrag/base.py CHANGED
@@ -363,7 +363,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
363
 
364
  async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
365
  """Get nodes as a batch using UNWIND
366
-
367
  Default implementation fetches nodes one by one.
368
  Override this method for better performance in storage backends
369
  that support batch operations.
@@ -377,7 +377,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
377
 
378
  async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
379
  """Node degrees as a batch using UNWIND
380
-
381
  Default implementation fetches node degrees one by one.
382
  Override this method for better performance in storage backends
383
  that support batch operations.
@@ -388,9 +388,11 @@ class BaseGraphStorage(StorageNameSpace, ABC):
388
  result[node_id] = degree
389
  return result
390
 
391
- async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
 
 
392
  """Edge degrees as a batch using UNWIND also uses node_degrees_batch
393
-
394
  Default implementation calculates edge degrees one by one.
395
  Override this method for better performance in storage backends
396
  that support batch operations.
@@ -401,9 +403,11 @@ class BaseGraphStorage(StorageNameSpace, ABC):
401
  result[(src_id, tgt_id)] = degree
402
  return result
403
 
404
- async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
 
 
405
  """Get edges as a batch using UNWIND
406
-
407
  Default implementation fetches edges one by one.
408
  Override this method for better performance in storage backends
409
  that support batch operations.
@@ -417,9 +421,11 @@ class BaseGraphStorage(StorageNameSpace, ABC):
417
  result[(src_id, tgt_id)] = edge
418
  return result
419
 
420
- async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
 
 
421
  """Get nodes edges as a batch using UNWIND
422
-
423
  Default implementation fetches node edges one by one.
424
  Override this method for better performance in storage backends
425
  that support batch operations.
 
363
 
364
  async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
365
  """Get nodes as a batch using UNWIND
366
+
367
  Default implementation fetches nodes one by one.
368
  Override this method for better performance in storage backends
369
  that support batch operations.
 
377
 
378
  async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
379
  """Node degrees as a batch using UNWIND
380
+
381
  Default implementation fetches node degrees one by one.
382
  Override this method for better performance in storage backends
383
  that support batch operations.
 
388
  result[node_id] = degree
389
  return result
390
 
391
+ async def edge_degrees_batch(
392
+ self, edge_pairs: list[tuple[str, str]]
393
+ ) -> dict[tuple[str, str], int]:
394
  """Edge degrees as a batch using UNWIND also uses node_degrees_batch
395
+
396
  Default implementation calculates edge degrees one by one.
397
  Override this method for better performance in storage backends
398
  that support batch operations.
 
403
  result[(src_id, tgt_id)] = degree
404
  return result
405
 
406
+ async def get_edges_batch(
407
+ self, pairs: list[dict[str, str]]
408
+ ) -> dict[tuple[str, str], dict]:
409
  """Get edges as a batch using UNWIND
410
+
411
  Default implementation fetches edges one by one.
412
  Override this method for better performance in storage backends
413
  that support batch operations.
 
421
  result[(src_id, tgt_id)] = edge
422
  return result
423
 
424
+ async def get_nodes_edges_batch(
425
+ self, node_ids: list[str]
426
+ ) -> dict[str, list[tuple[str, str]]]:
427
  """Get nodes edges as a batch using UNWIND
428
+
429
  Default implementation fetches node edges one by one.
430
  Override this method for better performance in storage backends
431
  that support batch operations.
lightrag/kg/neo4j_impl.py CHANGED
@@ -311,10 +311,10 @@ class Neo4JStorage(BaseGraphStorage):
311
  async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
312
  """
313
  Retrieve multiple nodes in one query using UNWIND.
314
-
315
  Args:
316
  node_ids: List of node entity IDs to fetch.
317
-
318
  Returns:
319
  A dictionary mapping each node_id to its node data (or None if not found).
320
  """
@@ -334,7 +334,9 @@ class Neo4JStorage(BaseGraphStorage):
334
  node_dict = dict(node)
335
  # Remove the 'base' label if present in a 'labels' property
336
  if "labels" in node_dict:
337
- node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
 
 
338
  nodes[entity_id] = node_dict
339
  await result.consume() # Make sure to consume the result fully
340
  return nodes
@@ -385,12 +387,12 @@ class Neo4JStorage(BaseGraphStorage):
385
  async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
386
  """
387
  Retrieve the degree for multiple nodes in a single query using UNWIND.
388
-
389
  Args:
390
  node_ids: List of node labels (entity_id values) to look up.
391
-
392
  Returns:
393
- A dictionary mapping each node_id to its degree (number of relationships).
394
  If a node is not found, its degree will be set to 0.
395
  """
396
  async with self._driver.session(
@@ -407,13 +409,13 @@ class Neo4JStorage(BaseGraphStorage):
407
  entity_id = record["entity_id"]
408
  degrees[entity_id] = record["degree"]
409
  await result.consume() # Ensure result is fully consumed
410
-
411
  # For any node_id that did not return a record, set degree to 0.
412
  for nid in node_ids:
413
  if nid not in degrees:
414
  logger.warning(f"No node found with label '{nid}'")
415
  degrees[nid] = 0
416
-
417
  logger.debug(f"Neo4j batch node degree query returned: {degrees}")
418
  return degrees
419
 
@@ -436,25 +438,27 @@ class Neo4JStorage(BaseGraphStorage):
436
 
437
  degrees = int(src_degree) + int(trg_degree)
438
  return degrees
439
-
440
- async def edge_degrees_batch(self, edge_pairs: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
 
 
441
  """
442
  Calculate the combined degree for each edge (sum of the source and target node degrees)
443
  in batch using the already implemented node_degrees_batch.
444
-
445
  Args:
446
  edge_pairs: List of (src, tgt) tuples.
447
-
448
  Returns:
449
  A dictionary mapping each (src, tgt) tuple to the sum of their degrees.
450
  """
451
  # Collect unique node IDs from all edge pairs.
452
  unique_node_ids = {src for src, _ in edge_pairs}
453
  unique_node_ids.update({tgt for _, tgt in edge_pairs})
454
-
455
  # Get degrees for all nodes in one go.
456
  degrees = await self.node_degrees_batch(list(unique_node_ids))
457
-
458
  # Sum up degrees for each edge pair.
459
  edge_degrees = {}
460
  for src, tgt in edge_pairs:
@@ -547,13 +551,15 @@ class Neo4JStorage(BaseGraphStorage):
547
  )
548
  raise
549
 
550
- async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
 
 
551
  """
552
  Retrieve edge properties for multiple (src, tgt) pairs in one query.
553
-
554
  Args:
555
  pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
556
-
557
  Returns:
558
  A dictionary mapping (src, tgt) tuples to their edge properties.
559
  """
@@ -574,13 +580,23 @@ class Neo4JStorage(BaseGraphStorage):
574
  if edges and len(edges) > 0:
575
  edge_props = edges[0] # choose the first if multiple exist
576
  # Ensure required keys exist with defaults
577
- for key, default in {"weight": 0.0, "source_id": None, "description": None, "keywords": None}.items():
 
 
 
 
 
578
  if key not in edge_props:
579
  edge_props[key] = default
580
  edges_dict[(src, tgt)] = edge_props
581
  else:
582
  # No edge found – set default edge properties
583
- edges_dict[(src, tgt)] = {"weight": 0.0, "source_id": None, "description": None, "keywords": None}
 
 
 
 
 
584
  await result.consume()
585
  return edges_dict
586
 
@@ -644,17 +660,21 @@ class Neo4JStorage(BaseGraphStorage):
644
  logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
645
  raise
646
 
647
- async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
 
 
648
  """
649
  Batch retrieve edges for multiple nodes in one query using UNWIND.
650
-
651
  Args:
652
  node_ids: List of node IDs (entity_id) for which to retrieve edges.
653
-
654
  Returns:
655
  A dictionary mapping each node ID to its list of edge tuples (source, target).
656
  """
657
- async with self._driver.session(database=self._DATABASE, default_access_mode="READ") as session:
 
 
658
  query = """
659
  UNWIND $node_ids AS id
660
  MATCH (n:base {entity_id: id})
 
311
  async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
312
  """
313
  Retrieve multiple nodes in one query using UNWIND.
314
+
315
  Args:
316
  node_ids: List of node entity IDs to fetch.
317
+
318
  Returns:
319
  A dictionary mapping each node_id to its node data (or None if not found).
320
  """
 
334
  node_dict = dict(node)
335
  # Remove the 'base' label if present in a 'labels' property
336
  if "labels" in node_dict:
337
+ node_dict["labels"] = [
338
+ label for label in node_dict["labels"] if label != "base"
339
+ ]
340
  nodes[entity_id] = node_dict
341
  await result.consume() # Make sure to consume the result fully
342
  return nodes
 
387
  async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
388
  """
389
  Retrieve the degree for multiple nodes in a single query using UNWIND.
390
+
391
  Args:
392
  node_ids: List of node labels (entity_id values) to look up.
393
+
394
  Returns:
395
+ A dictionary mapping each node_id to its degree (number of relationships).
396
  If a node is not found, its degree will be set to 0.
397
  """
398
  async with self._driver.session(
 
409
  entity_id = record["entity_id"]
410
  degrees[entity_id] = record["degree"]
411
  await result.consume() # Ensure result is fully consumed
412
+
413
  # For any node_id that did not return a record, set degree to 0.
414
  for nid in node_ids:
415
  if nid not in degrees:
416
  logger.warning(f"No node found with label '{nid}'")
417
  degrees[nid] = 0
418
+
419
  logger.debug(f"Neo4j batch node degree query returned: {degrees}")
420
  return degrees
421
 
 
438
 
439
  degrees = int(src_degree) + int(trg_degree)
440
  return degrees
441
+
442
+ async def edge_degrees_batch(
443
+ self, edge_pairs: list[tuple[str, str]]
444
+ ) -> dict[tuple[str, str], int]:
445
  """
446
  Calculate the combined degree for each edge (sum of the source and target node degrees)
447
  in batch using the already implemented node_degrees_batch.
448
+
449
  Args:
450
  edge_pairs: List of (src, tgt) tuples.
451
+
452
  Returns:
453
  A dictionary mapping each (src, tgt) tuple to the sum of their degrees.
454
  """
455
  # Collect unique node IDs from all edge pairs.
456
  unique_node_ids = {src for src, _ in edge_pairs}
457
  unique_node_ids.update({tgt for _, tgt in edge_pairs})
458
+
459
  # Get degrees for all nodes in one go.
460
  degrees = await self.node_degrees_batch(list(unique_node_ids))
461
+
462
  # Sum up degrees for each edge pair.
463
  edge_degrees = {}
464
  for src, tgt in edge_pairs:
 
551
  )
552
  raise
553
 
554
+ async def get_edges_batch(
555
+ self, pairs: list[dict[str, str]]
556
+ ) -> dict[tuple[str, str], dict]:
557
  """
558
  Retrieve edge properties for multiple (src, tgt) pairs in one query.
559
+
560
  Args:
561
  pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
562
+
563
  Returns:
564
  A dictionary mapping (src, tgt) tuples to their edge properties.
565
  """
 
580
  if edges and len(edges) > 0:
581
  edge_props = edges[0] # choose the first if multiple exist
582
  # Ensure required keys exist with defaults
583
+ for key, default in {
584
+ "weight": 0.0,
585
+ "source_id": None,
586
+ "description": None,
587
+ "keywords": None,
588
+ }.items():
589
  if key not in edge_props:
590
  edge_props[key] = default
591
  edges_dict[(src, tgt)] = edge_props
592
  else:
593
  # No edge found – set default edge properties
594
+ edges_dict[(src, tgt)] = {
595
+ "weight": 0.0,
596
+ "source_id": None,
597
+ "description": None,
598
+ "keywords": None,
599
+ }
600
  await result.consume()
601
  return edges_dict
602
 
 
660
  logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
661
  raise
662
 
663
+ async def get_nodes_edges_batch(
664
+ self, node_ids: list[str]
665
+ ) -> dict[str, list[tuple[str, str]]]:
666
  """
667
  Batch retrieve edges for multiple nodes in one query using UNWIND.
668
+
669
  Args:
670
  node_ids: List of node IDs (entity_id) for which to retrieve edges.
671
+
672
  Returns:
673
  A dictionary mapping each node ID to its list of edge tuples (source, target).
674
  """
675
+ async with self._driver.session(
676
+ database=self._DATABASE, default_access_mode="READ"
677
+ ) as session:
678
  query = """
679
  UNWIND $node_ids AS id
680
  MATCH (n:base {entity_id: id})
lightrag/kg/postgres_impl.py CHANGED
@@ -1461,30 +1461,29 @@ class PGGraphStorage(BaseGraphStorage):
1461
  async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
1462
  """
1463
  Retrieve multiple nodes in one query using UNWIND.
1464
-
1465
  Args:
1466
  node_ids: List of node entity IDs to fetch.
1467
-
1468
  Returns:
1469
  A dictionary mapping each node_id to its node data (or None if not found).
1470
  """
1471
  if not node_ids:
1472
  return {}
1473
-
1474
  # Format node IDs for the query
1475
- formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids])
1476
-
 
 
1477
  query = """SELECT * FROM cypher('%s', $$
1478
  UNWIND [%s] AS node_id
1479
  MATCH (n:base {entity_id: node_id})
1480
  RETURN node_id, n
1481
- $$) AS (node_id text, n agtype)""" % (
1482
- self.graph_name,
1483
- formatted_ids
1484
- )
1485
-
1486
  results = await self._query(query)
1487
-
1488
  # Build result dictionary
1489
  nodes_dict = {}
1490
  for result in results:
@@ -1492,28 +1491,32 @@ class PGGraphStorage(BaseGraphStorage):
1492
  node_dict = result["n"]["properties"]
1493
  # Remove the 'base' label if present in a 'labels' property
1494
  if "labels" in node_dict:
1495
- node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
 
 
1496
  nodes_dict[result["node_id"]] = node_dict
1497
-
1498
  return nodes_dict
1499
 
1500
  async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
1501
  """
1502
  Retrieve the degree for multiple nodes in a single query using UNWIND.
1503
-
1504
  Args:
1505
  node_ids: List of node labels (entity_id values) to look up.
1506
-
1507
  Returns:
1508
- A dictionary mapping each node_id to its degree (number of relationships).
1509
  If a node is not found, its degree will be set to 0.
1510
  """
1511
  if not node_ids:
1512
  return {}
1513
-
1514
  # Format node IDs for the query
1515
- formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids])
1516
-
 
 
1517
  query = """SELECT * FROM cypher('%s', $$
1518
  UNWIND [%s] AS node_id
1519
  MATCH (n:base {entity_id: node_id})
@@ -1521,112 +1524,122 @@ class PGGraphStorage(BaseGraphStorage):
1521
  RETURN node_id, count(r) AS degree
1522
  $$) AS (node_id text, degree bigint)""" % (
1523
  self.graph_name,
1524
- formatted_ids
1525
  )
1526
-
1527
  results = await self._query(query)
1528
-
1529
  # Build result dictionary
1530
  degrees_dict = {}
1531
  for result in results:
1532
  if result["node_id"] is not None:
1533
  degrees_dict[result["node_id"]] = int(result["degree"])
1534
-
1535
  # Ensure all requested node_ids are in the result dictionary
1536
  for node_id in node_ids:
1537
  if node_id not in degrees_dict:
1538
  degrees_dict[node_id] = 0
1539
-
1540
  return degrees_dict
1541
-
1542
- async def edge_degrees_batch(self, edges: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
 
 
1543
  """
1544
  Calculate the combined degree for each edge (sum of the source and target node degrees)
1545
  in batch using the already implemented node_degrees_batch.
1546
-
1547
  Args:
1548
  edges: List of (source_node_id, target_node_id) tuples
1549
-
1550
  Returns:
1551
  Dictionary mapping edge tuples to their combined degrees
1552
  """
1553
  if not edges:
1554
  return {}
1555
-
1556
  # Use node_degrees_batch to get all node degrees efficiently
1557
  all_nodes = set()
1558
  for src, tgt in edges:
1559
  all_nodes.add(src)
1560
  all_nodes.add(tgt)
1561
-
1562
  node_degrees = await self.node_degrees_batch(list(all_nodes))
1563
-
1564
  # Calculate edge degrees
1565
  edge_degrees_dict = {}
1566
  for src, tgt in edges:
1567
  src_degree = node_degrees.get(src, 0)
1568
  tgt_degree = node_degrees.get(tgt, 0)
1569
  edge_degrees_dict[(src, tgt)] = src_degree + tgt_degree
1570
-
1571
  return edge_degrees_dict
1572
-
1573
- async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
 
 
1574
  """
1575
  Retrieve edge properties for multiple (src, tgt) pairs in one query.
1576
-
1577
  Args:
1578
  pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
1579
-
1580
  Returns:
1581
  A dictionary mapping (src, tgt) tuples to their edge properties.
1582
  """
1583
  if not pairs:
1584
  return {}
1585
-
1586
  # 从字典列表中提取源节点和目标节点ID
1587
  src_nodes = []
1588
  tgt_nodes = []
1589
  for pair in pairs:
1590
- src_nodes.append(pair["src"].replace('"', ''))
1591
- tgt_nodes.append(pair["tgt"].replace('"', ''))
1592
-
1593
  # 构建查询,使用数组索引来匹配源节点和目标节点
1594
  src_array = ", ".join([f'"{src}"' for src in src_nodes])
1595
  tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])
1596
-
1597
  query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1598
  WITH [{src_array}] AS sources, [{tgt_array}] AS targets
1599
  UNWIND range(0, size(sources)-1) AS i
1600
  MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]-(b:base {{entity_id: targets[i]}})
1601
  RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
1602
  $$) AS (source text, target text, edge_properties agtype)"""
1603
-
1604
  results = await self._query(query)
1605
-
1606
  # 构建结果字典
1607
  edges_dict = {}
1608
  for result in results:
1609
  if result["source"] and result["target"] and result["edge_properties"]:
1610
- edges_dict[(result["source"], result["target"])] = result["edge_properties"]
1611
-
 
 
1612
  return edges_dict
1613
-
1614
- async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
 
 
1615
  """
1616
  Get all edges for multiple nodes in a single batch operation.
1617
-
1618
  Args:
1619
  node_ids: List of node IDs to get edges for
1620
-
1621
  Returns:
1622
  Dictionary mapping node IDs to lists of (source, target) edge tuples
1623
  """
1624
  if not node_ids:
1625
  return {}
1626
-
1627
  # Format node IDs for the query
1628
- formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids])
1629
-
 
 
1630
  query = """SELECT * FROM cypher('%s', $$
1631
  UNWIND [%s] AS node_id
1632
  MATCH (n:base {entity_id: node_id})
@@ -1634,11 +1647,11 @@ class PGGraphStorage(BaseGraphStorage):
1634
  RETURN node_id, connected.entity_id AS connected_id
1635
  $$) AS (node_id text, connected_id text)""" % (
1636
  self.graph_name,
1637
- formatted_ids
1638
  )
1639
-
1640
  results = await self._query(query)
1641
-
1642
  # Build result dictionary
1643
  nodes_edges_dict = {node_id: [] for node_id in node_ids}
1644
  for result in results:
@@ -1646,9 +1659,9 @@ class PGGraphStorage(BaseGraphStorage):
1646
  nodes_edges_dict[result["node_id"]].append(
1647
  (result["node_id"], result["connected_id"])
1648
  )
1649
-
1650
  return nodes_edges_dict
1651
-
1652
  async def get_all_labels(self) -> list[str]:
1653
  """
1654
  Get all labels (node IDs) in the graph.
 
1461
  async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
1462
  """
1463
  Retrieve multiple nodes in one query using UNWIND.
1464
+
1465
  Args:
1466
  node_ids: List of node entity IDs to fetch.
1467
+
1468
  Returns:
1469
  A dictionary mapping each node_id to its node data (or None if not found).
1470
  """
1471
  if not node_ids:
1472
  return {}
1473
+
1474
  # Format node IDs for the query
1475
+ formatted_ids = ", ".join(
1476
+ ['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
1477
+ )
1478
+
1479
  query = """SELECT * FROM cypher('%s', $$
1480
  UNWIND [%s] AS node_id
1481
  MATCH (n:base {entity_id: node_id})
1482
  RETURN node_id, n
1483
+ $$) AS (node_id text, n agtype)""" % (self.graph_name, formatted_ids)
1484
+
 
 
 
1485
  results = await self._query(query)
1486
+
1487
  # Build result dictionary
1488
  nodes_dict = {}
1489
  for result in results:
 
1491
  node_dict = result["n"]["properties"]
1492
  # Remove the 'base' label if present in a 'labels' property
1493
  if "labels" in node_dict:
1494
+ node_dict["labels"] = [
1495
+ label for label in node_dict["labels"] if label != "base"
1496
+ ]
1497
  nodes_dict[result["node_id"]] = node_dict
1498
+
1499
  return nodes_dict
1500
 
1501
  async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
1502
  """
1503
  Retrieve the degree for multiple nodes in a single query using UNWIND.
1504
+
1505
  Args:
1506
  node_ids: List of node labels (entity_id values) to look up.
1507
+
1508
  Returns:
1509
+ A dictionary mapping each node_id to its degree (number of relationships).
1510
  If a node is not found, its degree will be set to 0.
1511
  """
1512
  if not node_ids:
1513
  return {}
1514
+
1515
  # Format node IDs for the query
1516
+ formatted_ids = ", ".join(
1517
+ ['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
1518
+ )
1519
+
1520
  query = """SELECT * FROM cypher('%s', $$
1521
  UNWIND [%s] AS node_id
1522
  MATCH (n:base {entity_id: node_id})
 
1524
  RETURN node_id, count(r) AS degree
1525
  $$) AS (node_id text, degree bigint)""" % (
1526
  self.graph_name,
1527
+ formatted_ids,
1528
  )
1529
+
1530
  results = await self._query(query)
1531
+
1532
  # Build result dictionary
1533
  degrees_dict = {}
1534
  for result in results:
1535
  if result["node_id"] is not None:
1536
  degrees_dict[result["node_id"]] = int(result["degree"])
1537
+
1538
  # Ensure all requested node_ids are in the result dictionary
1539
  for node_id in node_ids:
1540
  if node_id not in degrees_dict:
1541
  degrees_dict[node_id] = 0
1542
+
1543
  return degrees_dict
1544
+
1545
+ async def edge_degrees_batch(
1546
+ self, edges: list[tuple[str, str]]
1547
+ ) -> dict[tuple[str, str], int]:
1548
  """
1549
  Calculate the combined degree for each edge (sum of the source and target node degrees)
1550
  in batch using the already implemented node_degrees_batch.
1551
+
1552
  Args:
1553
  edges: List of (source_node_id, target_node_id) tuples
1554
+
1555
  Returns:
1556
  Dictionary mapping edge tuples to their combined degrees
1557
  """
1558
  if not edges:
1559
  return {}
1560
+
1561
  # Use node_degrees_batch to get all node degrees efficiently
1562
  all_nodes = set()
1563
  for src, tgt in edges:
1564
  all_nodes.add(src)
1565
  all_nodes.add(tgt)
1566
+
1567
  node_degrees = await self.node_degrees_batch(list(all_nodes))
1568
+
1569
  # Calculate edge degrees
1570
  edge_degrees_dict = {}
1571
  for src, tgt in edges:
1572
  src_degree = node_degrees.get(src, 0)
1573
  tgt_degree = node_degrees.get(tgt, 0)
1574
  edge_degrees_dict[(src, tgt)] = src_degree + tgt_degree
1575
+
1576
  return edge_degrees_dict
1577
+
1578
+ async def get_edges_batch(
1579
+ self, pairs: list[dict[str, str]]
1580
+ ) -> dict[tuple[str, str], dict]:
1581
  """
1582
  Retrieve edge properties for multiple (src, tgt) pairs in one query.
1583
+
1584
  Args:
1585
  pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
1586
+
1587
  Returns:
1588
  A dictionary mapping (src, tgt) tuples to their edge properties.
1589
  """
1590
  if not pairs:
1591
  return {}
1592
+
1593
  # 从字典列表中提取源节点和目标节点ID
1594
  src_nodes = []
1595
  tgt_nodes = []
1596
  for pair in pairs:
1597
+ src_nodes.append(pair["src"].replace('"', ""))
1598
+ tgt_nodes.append(pair["tgt"].replace('"', ""))
1599
+
1600
  # 构建查询,使用数组索引来匹配源节点和目标节点
1601
  src_array = ", ".join([f'"{src}"' for src in src_nodes])
1602
  tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])
1603
+
1604
  query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1605
  WITH [{src_array}] AS sources, [{tgt_array}] AS targets
1606
  UNWIND range(0, size(sources)-1) AS i
1607
  MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]-(b:base {{entity_id: targets[i]}})
1608
  RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
1609
  $$) AS (source text, target text, edge_properties agtype)"""
1610
+
1611
  results = await self._query(query)
1612
+
1613
  # 构建结果字典
1614
  edges_dict = {}
1615
  for result in results:
1616
  if result["source"] and result["target"] and result["edge_properties"]:
1617
+ edges_dict[(result["source"], result["target"])] = result[
1618
+ "edge_properties"
1619
+ ]
1620
+
1621
  return edges_dict
1622
+
1623
+ async def get_nodes_edges_batch(
1624
+ self, node_ids: list[str]
1625
+ ) -> dict[str, list[tuple[str, str]]]:
1626
  """
1627
  Get all edges for multiple nodes in a single batch operation.
1628
+
1629
  Args:
1630
  node_ids: List of node IDs to get edges for
1631
+
1632
  Returns:
1633
  Dictionary mapping node IDs to lists of (source, target) edge tuples
1634
  """
1635
  if not node_ids:
1636
  return {}
1637
+
1638
  # Format node IDs for the query
1639
+ formatted_ids = ", ".join(
1640
+ ['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
1641
+ )
1642
+
1643
  query = """SELECT * FROM cypher('%s', $$
1644
  UNWIND [%s] AS node_id
1645
  MATCH (n:base {entity_id: node_id})
 
1647
  RETURN node_id, connected.entity_id AS connected_id
1648
  $$) AS (node_id text, connected_id text)""" % (
1649
  self.graph_name,
1650
+ formatted_ids,
1651
  )
1652
+
1653
  results = await self._query(query)
1654
+
1655
  # Build result dictionary
1656
  nodes_edges_dict = {node_id: [] for node_id in node_ids}
1657
  for result in results:
 
1659
  nodes_edges_dict[result["node_id"]].append(
1660
  (result["node_id"], result["connected_id"])
1661
  )
1662
+
1663
  return nodes_edges_dict
1664
+
1665
  async def get_all_labels(self) -> list[str]:
1666
  """
1667
  Get all labels (node IDs) in the graph.
lightrag/operate.py CHANGED
@@ -1323,14 +1323,14 @@ async def _get_node_data(
1323
 
1324
  if not len(results):
1325
  return "", "", ""
1326
-
1327
  # Extract all entity IDs from your results list
1328
  node_ids = [r["entity_name"] for r in results]
1329
 
1330
  # Call the batch node retrieval and degree functions concurrently.
1331
  nodes_dict, degrees_dict = await asyncio.gather(
1332
- knowledge_graph_inst.get_nodes_batch(node_ids),
1333
- knowledge_graph_inst.node_degrees_batch(node_ids)
1334
  )
1335
 
1336
  # Now, if you need the node data and degree in order:
@@ -1459,7 +1459,7 @@ async def _find_most_related_text_unit_from_entities(
1459
  for dp in node_datas
1460
  if dp["source_id"] is not None
1461
  ]
1462
-
1463
  node_names = [dp["entity_name"] for dp in node_datas]
1464
  batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names)
1465
  # Build the edges list in the same order as node_datas.
@@ -1472,10 +1472,14 @@ async def _find_most_related_text_unit_from_entities(
1472
  all_one_hop_nodes.update([e[1] for e in this_edges])
1473
 
1474
  all_one_hop_nodes = list(all_one_hop_nodes)
1475
-
1476
  # Batch retrieve one-hop node data using get_nodes_batch
1477
- all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(all_one_hop_nodes)
1478
- all_one_hop_nodes_data = [all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes]
 
 
 
 
1479
 
1480
  # Add null check for node data
1481
  all_one_hop_text_units_lookup = {
@@ -1571,13 +1575,13 @@ async def _find_most_related_edges_from_entities(
1571
  edge_pairs_dicts = [{"src": e[0], "tgt": e[1]} for e in all_edges]
1572
  # For edge degrees, use tuples.
1573
  edge_pairs_tuples = list(all_edges) # all_edges is already a list of tuples
1574
-
1575
  # Call the batched functions concurrently.
1576
  edge_data_dict, edge_degrees_dict = await asyncio.gather(
1577
  knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
1578
- knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples)
1579
  )
1580
-
1581
  # Reconstruct edge_datas list in the same order as the deduplicated results.
1582
  all_edges_data = []
1583
  for pair in all_edges:
@@ -1590,7 +1594,6 @@ async def _find_most_related_edges_from_entities(
1590
  }
1591
  all_edges_data.append(combined)
1592
 
1593
-
1594
  all_edges_data = sorted(
1595
  all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
1596
  )
@@ -1634,7 +1637,7 @@ async def _get_edge_data(
1634
  # Call the batched functions concurrently.
1635
  edge_data_dict, edge_degrees_dict = await asyncio.gather(
1636
  knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
1637
- knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples)
1638
  )
1639
 
1640
  # Reconstruct edge_datas list in the same order as results.
@@ -1652,7 +1655,7 @@ async def _get_edge_data(
1652
  **edge_props,
1653
  }
1654
  edge_datas.append(combined)
1655
-
1656
  edge_datas = sorted(
1657
  edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
1658
  )
@@ -1761,7 +1764,7 @@ async def _find_most_related_entities_from_relationships(
1761
  # Batch approach: Retrieve nodes and their degrees concurrently with one query each.
1762
  nodes_dict, degrees_dict = await asyncio.gather(
1763
  knowledge_graph_inst.get_nodes_batch(entity_names),
1764
- knowledge_graph_inst.node_degrees_batch(entity_names)
1765
  )
1766
 
1767
  # Rebuild the list in the same order as entity_names
 
1323
 
1324
  if not len(results):
1325
  return "", "", ""
1326
+
1327
  # Extract all entity IDs from your results list
1328
  node_ids = [r["entity_name"] for r in results]
1329
 
1330
  # Call the batch node retrieval and degree functions concurrently.
1331
  nodes_dict, degrees_dict = await asyncio.gather(
1332
+ knowledge_graph_inst.get_nodes_batch(node_ids),
1333
+ knowledge_graph_inst.node_degrees_batch(node_ids),
1334
  )
1335
 
1336
  # Now, if you need the node data and degree in order:
 
1459
  for dp in node_datas
1460
  if dp["source_id"] is not None
1461
  ]
1462
+
1463
  node_names = [dp["entity_name"] for dp in node_datas]
1464
  batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names)
1465
  # Build the edges list in the same order as node_datas.
 
1472
  all_one_hop_nodes.update([e[1] for e in this_edges])
1473
 
1474
  all_one_hop_nodes = list(all_one_hop_nodes)
1475
+
1476
  # Batch retrieve one-hop node data using get_nodes_batch
1477
+ all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(
1478
+ all_one_hop_nodes
1479
+ )
1480
+ all_one_hop_nodes_data = [
1481
+ all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes
1482
+ ]
1483
 
1484
  # Add null check for node data
1485
  all_one_hop_text_units_lookup = {
 
1575
  edge_pairs_dicts = [{"src": e[0], "tgt": e[1]} for e in all_edges]
1576
  # For edge degrees, use tuples.
1577
  edge_pairs_tuples = list(all_edges) # all_edges is already a list of tuples
1578
+
1579
  # Call the batched functions concurrently.
1580
  edge_data_dict, edge_degrees_dict = await asyncio.gather(
1581
  knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
1582
+ knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples),
1583
  )
1584
+
1585
  # Reconstruct edge_datas list in the same order as the deduplicated results.
1586
  all_edges_data = []
1587
  for pair in all_edges:
 
1594
  }
1595
  all_edges_data.append(combined)
1596
 
 
1597
  all_edges_data = sorted(
1598
  all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
1599
  )
 
1637
  # Call the batched functions concurrently.
1638
  edge_data_dict, edge_degrees_dict = await asyncio.gather(
1639
  knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
1640
+ knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples),
1641
  )
1642
 
1643
  # Reconstruct edge_datas list in the same order as results.
 
1655
  **edge_props,
1656
  }
1657
  edge_datas.append(combined)
1658
+
1659
  edge_datas = sorted(
1660
  edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
1661
  )
 
1764
  # Batch approach: Retrieve nodes and their degrees concurrently with one query each.
1765
  nodes_dict, degrees_dict = await asyncio.gather(
1766
  knowledge_graph_inst.get_nodes_batch(entity_names),
1767
+ knowledge_graph_inst.node_degrees_batch(entity_names),
1768
  )
1769
 
1770
  # Rebuild the list in the same order as entity_names
lightrag_webui/src/stores/graph.ts CHANGED
@@ -136,7 +136,7 @@ interface GraphState {
136
  // Version counter to trigger data refresh
137
  graphDataVersion: number
138
  incrementGraphDataVersion: () => void
139
-
140
  // Methods for updating graph elements and UI state together
141
  updateNodeAndSelect: (nodeId: string, entityId: string, propertyName: string, newValue: string) => Promise<void>
142
  updateEdgeAndSelect: (edgeId: string, dynamicId: string, sourceId: string, targetId: string, propertyName: string, newValue: string) => Promise<void>
@@ -252,40 +252,40 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
252
  // Get current state
253
  const state = get()
254
  const { sigmaGraph, rawGraph } = state
255
-
256
  // Validate graph state
257
  if (!sigmaGraph || !rawGraph || !sigmaGraph.hasNode(nodeId)) {
258
  return
259
  }
260
-
261
  try {
262
  const nodeAttributes = sigmaGraph.getNodeAttributes(nodeId)
263
-
264
  console.log('updateNodeAndSelect', nodeId, entityId, propertyName, newValue)
265
-
266
  // For entity_id changes (node renaming) with NetworkX graph storage
267
  if ((nodeId === entityId) && (propertyName === 'entity_id')) {
268
  // Create new node with updated ID but same attributes
269
  sigmaGraph.addNode(newValue, { ...nodeAttributes, label: newValue })
270
-
271
  const edgesToUpdate: EdgeToUpdate[] = []
272
-
273
  // Process all edges connected to this node
274
  sigmaGraph.forEachEdge(nodeId, (edge, attributes, source, target) => {
275
  const otherNode = source === nodeId ? target : source
276
  const isOutgoing = source === nodeId
277
-
278
  // Get original edge dynamic ID for later reference
279
  const originalEdgeDynamicId = edge
280
  const edgeIndexInRawGraph = rawGraph.edgeDynamicIdMap[originalEdgeDynamicId]
281
-
282
  // Create new edge with updated node reference
283
  const newEdgeId = sigmaGraph.addEdge(
284
  isOutgoing ? newValue : otherNode,
285
  isOutgoing ? otherNode : newValue,
286
  attributes
287
  )
288
-
289
  // Track edges that need updating in the raw graph
290
  if (edgeIndexInRawGraph !== undefined) {
291
  edgesToUpdate.push({
@@ -294,14 +294,14 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
294
  edgeIndex: edgeIndexInRawGraph
295
  })
296
  }
297
-
298
  // Remove the old edge
299
  sigmaGraph.dropEdge(edge)
300
  })
301
-
302
  // Remove the old node after all edges are processed
303
  sigmaGraph.dropNode(nodeId)
304
-
305
  // Update node reference in raw graph data
306
  const nodeIndex = rawGraph.nodeIdMap[nodeId]
307
  if (nodeIndex !== undefined) {
@@ -311,7 +311,7 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
311
  delete rawGraph.nodeIdMap[nodeId]
312
  rawGraph.nodeIdMap[newValue] = nodeIndex
313
  }
314
-
315
  // Update all edge references in raw graph data
316
  edgesToUpdate.forEach(({ originalDynamicId, newEdgeId, edgeIndex }) => {
317
  if (rawGraph.edges[edgeIndex]) {
@@ -322,14 +322,14 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
322
  if (rawGraph.edges[edgeIndex].target === nodeId) {
323
  rawGraph.edges[edgeIndex].target = newValue
324
  }
325
-
326
  // Update dynamic ID mappings
327
  rawGraph.edges[edgeIndex].dynamicId = newEdgeId
328
  delete rawGraph.edgeDynamicIdMap[originalDynamicId]
329
  rawGraph.edgeDynamicIdMap[newEdgeId] = edgeIndex
330
  }
331
  })
332
-
333
  // Update selected node in store
334
  set({ selectedNode: newValue, moveToSelectedNode: true })
335
  } else {
@@ -342,7 +342,7 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
342
  sigmaGraph.setNodeAttribute(String(nodeId), 'label', newValue)
343
  }
344
  }
345
-
346
  // Trigger a re-render by incrementing the version counter
347
  set((state) => ({ graphDataVersion: state.graphDataVersion + 1 }))
348
  }
@@ -351,17 +351,17 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
351
  throw new Error('Failed to update node in graph')
352
  }
353
  },
354
-
355
  updateEdgeAndSelect: async (edgeId: string, dynamicId: string, sourceId: string, targetId: string, propertyName: string, newValue: string) => {
356
  // Get current state
357
  const state = get()
358
  const { sigmaGraph, rawGraph } = state
359
-
360
  // Validate graph state
361
  if (!sigmaGraph || !rawGraph) {
362
  return
363
  }
364
-
365
  try {
366
  const edgeIndex = rawGraph.edgeIdMap[String(edgeId)]
367
  if (edgeIndex !== undefined && rawGraph.edges[edgeIndex]) {
@@ -370,10 +370,10 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
370
  sigmaGraph.setEdgeAttribute(dynamicId, 'label', newValue)
371
  }
372
  }
373
-
374
  // Trigger a re-render by incrementing the version counter
375
  set((state) => ({ graphDataVersion: state.graphDataVersion + 1 }))
376
-
377
  // Update selected edge in store to ensure UI reflects changes
378
  set({ selectedEdge: dynamicId })
379
  } catch (error) {
 
136
  // Version counter to trigger data refresh
137
  graphDataVersion: number
138
  incrementGraphDataVersion: () => void
139
+
140
  // Methods for updating graph elements and UI state together
141
  updateNodeAndSelect: (nodeId: string, entityId: string, propertyName: string, newValue: string) => Promise<void>
142
  updateEdgeAndSelect: (edgeId: string, dynamicId: string, sourceId: string, targetId: string, propertyName: string, newValue: string) => Promise<void>
 
252
  // Get current state
253
  const state = get()
254
  const { sigmaGraph, rawGraph } = state
255
+
256
  // Validate graph state
257
  if (!sigmaGraph || !rawGraph || !sigmaGraph.hasNode(nodeId)) {
258
  return
259
  }
260
+
261
  try {
262
  const nodeAttributes = sigmaGraph.getNodeAttributes(nodeId)
263
+
264
  console.log('updateNodeAndSelect', nodeId, entityId, propertyName, newValue)
265
+
266
  // For entity_id changes (node renaming) with NetworkX graph storage
267
  if ((nodeId === entityId) && (propertyName === 'entity_id')) {
268
  // Create new node with updated ID but same attributes
269
  sigmaGraph.addNode(newValue, { ...nodeAttributes, label: newValue })
270
+
271
  const edgesToUpdate: EdgeToUpdate[] = []
272
+
273
  // Process all edges connected to this node
274
  sigmaGraph.forEachEdge(nodeId, (edge, attributes, source, target) => {
275
  const otherNode = source === nodeId ? target : source
276
  const isOutgoing = source === nodeId
277
+
278
  // Get original edge dynamic ID for later reference
279
  const originalEdgeDynamicId = edge
280
  const edgeIndexInRawGraph = rawGraph.edgeDynamicIdMap[originalEdgeDynamicId]
281
+
282
  // Create new edge with updated node reference
283
  const newEdgeId = sigmaGraph.addEdge(
284
  isOutgoing ? newValue : otherNode,
285
  isOutgoing ? otherNode : newValue,
286
  attributes
287
  )
288
+
289
  // Track edges that need updating in the raw graph
290
  if (edgeIndexInRawGraph !== undefined) {
291
  edgesToUpdate.push({
 
294
  edgeIndex: edgeIndexInRawGraph
295
  })
296
  }
297
+
298
  // Remove the old edge
299
  sigmaGraph.dropEdge(edge)
300
  })
301
+
302
  // Remove the old node after all edges are processed
303
  sigmaGraph.dropNode(nodeId)
304
+
305
  // Update node reference in raw graph data
306
  const nodeIndex = rawGraph.nodeIdMap[nodeId]
307
  if (nodeIndex !== undefined) {
 
311
  delete rawGraph.nodeIdMap[nodeId]
312
  rawGraph.nodeIdMap[newValue] = nodeIndex
313
  }
314
+
315
  // Update all edge references in raw graph data
316
  edgesToUpdate.forEach(({ originalDynamicId, newEdgeId, edgeIndex }) => {
317
  if (rawGraph.edges[edgeIndex]) {
 
322
  if (rawGraph.edges[edgeIndex].target === nodeId) {
323
  rawGraph.edges[edgeIndex].target = newValue
324
  }
325
+
326
  // Update dynamic ID mappings
327
  rawGraph.edges[edgeIndex].dynamicId = newEdgeId
328
  delete rawGraph.edgeDynamicIdMap[originalDynamicId]
329
  rawGraph.edgeDynamicIdMap[newEdgeId] = edgeIndex
330
  }
331
  })
332
+
333
  // Update selected node in store
334
  set({ selectedNode: newValue, moveToSelectedNode: true })
335
  } else {
 
342
  sigmaGraph.setNodeAttribute(String(nodeId), 'label', newValue)
343
  }
344
  }
345
+
346
  // Trigger a re-render by incrementing the version counter
347
  set((state) => ({ graphDataVersion: state.graphDataVersion + 1 }))
348
  }
 
351
  throw new Error('Failed to update node in graph')
352
  }
353
  },
354
+
355
  updateEdgeAndSelect: async (edgeId: string, dynamicId: string, sourceId: string, targetId: string, propertyName: string, newValue: string) => {
356
  // Get current state
357
  const state = get()
358
  const { sigmaGraph, rawGraph } = state
359
+
360
  // Validate graph state
361
  if (!sigmaGraph || !rawGraph) {
362
  return
363
  }
364
+
365
  try {
366
  const edgeIndex = rawGraph.edgeIdMap[String(edgeId)]
367
  if (edgeIndex !== undefined && rawGraph.edges[edgeIndex]) {
 
370
  sigmaGraph.setEdgeAttribute(dynamicId, 'label', newValue)
371
  }
372
  }
373
+
374
  // Trigger a re-render by incrementing the version counter
375
  set((state) => ({ graphDataVersion: state.graphDataVersion + 1 }))
376
+
377
  // Update selected edge in store to ensure UI reflects changes
378
  set({ selectedEdge: dynamicId })
379
  } catch (error) {
lightrag_webui/src/utils/graphOperations.ts CHANGED
@@ -3,7 +3,7 @@ import { useGraphStore } from '@/stores/graph'
3
  /**
4
  * Update node in the graph visualization
5
  * This function is now a wrapper around the store's updateNodeAndSelect method
6
- *
7
  * @param nodeId - ID of the node to update
8
  * @param entityId - ID of the entity
9
  * @param propertyName - Name of the property being updated
 
3
  /**
4
  * Update node in the graph visualization
5
  * This function is now a wrapper around the store's updateNodeAndSelect method
6
+ *
7
  * @param nodeId - ID of the node to update
8
  * @param entityId - ID of the entity
9
  * @param propertyName - Name of the property being updated
tests/test_graph_storage.py CHANGED
@@ -510,35 +510,66 @@ async def test_graph_batch_operations(storage):
510
  assert node1_id in nodes_dict, f"{node1_id} 应在返回结果中"
511
  assert node2_id in nodes_dict, f"{node2_id} 应在返回结果中"
512
  assert node3_id in nodes_dict, f"{node3_id} 应在返回结果中"
513
- assert nodes_dict[node1_id]["description"] == node1_data["description"], f"{node1_id} 描述不匹配"
514
- assert nodes_dict[node2_id]["description"] == node2_data["description"], f"{node2_id} 描述不匹配"
515
- assert nodes_dict[node3_id]["description"] == node3_data["description"], f"{node3_id} 描述不匹配"
 
 
 
 
 
 
516
 
517
  # 3. 测试 node_degrees_batch - 批量获取多个节点的度数
518
  print("== 测试 node_degrees_batch")
519
  node_degrees = await storage.node_degrees_batch(node_ids)
520
  print(f"批量获取节点度数结果: {node_degrees}")
521
- assert len(node_degrees) == 3, f"应返回3个节点的度数,实际返回 {len(node_degrees)} 个"
 
 
522
  assert node1_id in node_degrees, f"{node1_id} 应在返回结果中"
523
  assert node2_id in node_degrees, f"{node2_id} 应在返回结果中"
524
  assert node3_id in node_degrees, f"{node3_id} 应在返回结果中"
525
- assert node_degrees[node1_id] == 3, f"{node1_id} 度数应为3,实际为 {node_degrees[node1_id]}"
526
- assert node_degrees[node2_id] == 2, f"{node2_id} 度数应为2,实际为 {node_degrees[node2_id]}"
527
- assert node_degrees[node3_id] == 3, f"{node3_id} 度数应为3,实际为 {node_degrees[node3_id]}"
 
 
 
 
 
 
528
 
529
  # 4. 测试 edge_degrees_batch - 批量获取多个边的度数
530
  print("== 测试 edge_degrees_batch")
531
  edges = [(node1_id, node2_id), (node2_id, node3_id), (node3_id, node4_id)]
532
  edge_degrees = await storage.edge_degrees_batch(edges)
533
  print(f"批量获取边度数结果: {edge_degrees}")
534
- assert len(edge_degrees) == 3, f"应返回3条边的度数,实际返回 {len(edge_degrees)} 条"
535
- assert (node1_id, node2_id) in edge_degrees, f"边 {node1_id} -> {node2_id} 应在返回结果中"
536
- assert (node2_id, node3_id) in edge_degrees, f"边 {node2_id} -> {node3_id} 应在返回结果中"
537
- assert (node3_id, node4_id) in edge_degrees, f"边 {node3_id} -> {node4_id} 应在返回结果中"
 
 
 
 
 
 
 
 
 
 
 
538
  # 验证边的度数是否正确(源节点度数 + 目标节点度数)
539
- assert edge_degrees[(node1_id, node2_id)] == 5, f"边 {node1_id} -> {node2_id} 度数应为5,实际为 {edge_degrees[(node1_id, node2_id)]}"
540
- assert edge_degrees[(node2_id, node3_id)] == 5, f"边 {node2_id} -> {node3_id} 度数应为5,实际为 {edge_degrees[(node2_id, node3_id)]}"
541
- assert edge_degrees[(node3_id, node4_id)] == 5, f"边 {node3_id} -> {node4_id} 度数应为5,实际为 {edge_degrees[(node3_id, node4_id)]}"
 
 
 
 
 
 
542
 
543
  # 5. 测试 get_edges_batch - 批量获取多个边的属性
544
  print("== 测试 get_edges_batch")
@@ -547,28 +578,54 @@ async def test_graph_batch_operations(storage):
547
  edges_dict = await storage.get_edges_batch(edge_dicts)
548
  print(f"批量获取边属性结果: {edges_dict.keys()}")
549
  assert len(edges_dict) == 3, f"应返回3条边的属性,实际返回 {len(edges_dict)} 条"
550
- assert (node1_id, node2_id) in edges_dict, f"边 {node1_id} -> {node2_id} 应在返回结果中"
551
- assert (node2_id, node3_id) in edges_dict, f"边 {node2_id} -> {node3_id} 应在返回结果中"
552
- assert (node3_id, node4_id) in edges_dict, f"边 {node3_id} -> {node4_id} 应在返回结果中"
553
- assert edges_dict[(node1_id, node2_id)]["relationship"] == edge1_data["relationship"], f"边 {node1_id} -> {node2_id} 关系不匹配"
554
- assert edges_dict[(node2_id, node3_id)]["relationship"] == edge2_data["relationship"], f"边 {node2_id} -> {node3_id} 关系不匹配"
555
- assert edges_dict[(node3_id, node4_id)]["relationship"] == edge5_data["relationship"], f"边 {node3_id} -> {node4_id} 关系不匹配"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
556
 
557
  # 6. 测试 get_nodes_edges_batch - 批量获取多个节点的所有边
558
  print("== 测试 get_nodes_edges_batch")
559
  nodes_edges = await storage.get_nodes_edges_batch([node1_id, node3_id])
560
  print(f"批量获取节点边结果: {nodes_edges.keys()}")
561
- assert len(nodes_edges) == 2, f"应返回2个节点的边,实际返回 {len(nodes_edges)} 个"
 
 
562
  assert node1_id in nodes_edges, f"{node1_id} 应在返回结果中"
563
  assert node3_id in nodes_edges, f"{node3_id} 应在返回结果中"
564
- assert len(nodes_edges[node1_id]) == 3, f"{node1_id} 应有3条边,实际有 {len(nodes_edges[node1_id])} 条"
565
- assert len(nodes_edges[node3_id]) == 3, f"{node3_id} 应有3条边,实际有 {len(nodes_edges[node3_id])} 条"
 
 
 
 
566
 
567
  # 7. 清理数据
568
  print("== 测试 drop")
569
  result = await storage.drop()
570
  print(f"清理结果: {result}")
571
- assert result["status"] == "success", f"清理应成功,实际状态为 {result['status']}"
 
 
572
 
573
  print("\n批量操作测试完成")
574
  return True
@@ -630,7 +687,7 @@ async def main():
630
  if basic_result:
631
  ASCIIColors.cyan("\n=== 开始高级测试 ===")
632
  advanced_result = await test_graph_advanced(storage)
633
-
634
  if advanced_result:
635
  ASCIIColors.cyan("\n=== 开始批量操作测试 ===")
636
  await test_graph_batch_operations(storage)
 
510
  assert node1_id in nodes_dict, f"{node1_id} 应在返回结果中"
511
  assert node2_id in nodes_dict, f"{node2_id} 应在返回结果中"
512
  assert node3_id in nodes_dict, f"{node3_id} 应在返回结果中"
513
+ assert (
514
+ nodes_dict[node1_id]["description"] == node1_data["description"]
515
+ ), f"{node1_id} 描述不匹配"
516
+ assert (
517
+ nodes_dict[node2_id]["description"] == node2_data["description"]
518
+ ), f"{node2_id} 描述不匹配"
519
+ assert (
520
+ nodes_dict[node3_id]["description"] == node3_data["description"]
521
+ ), f"{node3_id} 描述不匹配"
522
 
523
  # 3. 测试 node_degrees_batch - 批量获取多个节点的度数
524
  print("== 测试 node_degrees_batch")
525
  node_degrees = await storage.node_degrees_batch(node_ids)
526
  print(f"批量获取节点度数结果: {node_degrees}")
527
+ assert (
528
+ len(node_degrees) == 3
529
+ ), f"应返回3个节点的度数,实际返回 {len(node_degrees)} 个"
530
  assert node1_id in node_degrees, f"{node1_id} 应在返回结果中"
531
  assert node2_id in node_degrees, f"{node2_id} 应在返回结果中"
532
  assert node3_id in node_degrees, f"{node3_id} 应在返回结果中"
533
+ assert (
534
+ node_degrees[node1_id] == 3
535
+ ), f"{node1_id} 度数应为3,实际为 {node_degrees[node1_id]}"
536
+ assert (
537
+ node_degrees[node2_id] == 2
538
+ ), f"{node2_id} 度数应为2,实际为 {node_degrees[node2_id]}"
539
+ assert (
540
+ node_degrees[node3_id] == 3
541
+ ), f"{node3_id} 度数应为3,实际为 {node_degrees[node3_id]}"
542
 
543
  # 4. 测试 edge_degrees_batch - 批量获取多个边的度数
544
  print("== 测试 edge_degrees_batch")
545
  edges = [(node1_id, node2_id), (node2_id, node3_id), (node3_id, node4_id)]
546
  edge_degrees = await storage.edge_degrees_batch(edges)
547
  print(f"批量获取边度数结果: {edge_degrees}")
548
+ assert (
549
+ len(edge_degrees) == 3
550
+ ), f"应返回3条边的度数,实际返回 {len(edge_degrees)} 条"
551
+ assert (
552
+ node1_id,
553
+ node2_id,
554
+ ) in edge_degrees, f"边 {node1_id} -> {node2_id} 应在返回结果中"
555
+ assert (
556
+ node2_id,
557
+ node3_id,
558
+ ) in edge_degrees, f"边 {node2_id} -> {node3_id} 应在返回结果中"
559
+ assert (
560
+ node3_id,
561
+ node4_id,
562
+ ) in edge_degrees, f"边 {node3_id} -> {node4_id} 应在返回结果中"
563
  # 验证边的度数是否正确(源节点度数 + 目标节点度数)
564
+ assert (
565
+ edge_degrees[(node1_id, node2_id)] == 5
566
+ ), f"边 {node1_id} -> {node2_id} 度数应为5,实际为 {edge_degrees[(node1_id, node2_id)]}"
567
+ assert (
568
+ edge_degrees[(node2_id, node3_id)] == 5
569
+ ), f"边 {node2_id} -> {node3_id} 度数应为5,实际为 {edge_degrees[(node2_id, node3_id)]}"
570
+ assert (
571
+ edge_degrees[(node3_id, node4_id)] == 5
572
+ ), f"边 {node3_id} -> {node4_id} 度数应为5,实际为 {edge_degrees[(node3_id, node4_id)]}"
573
 
574
  # 5. 测试 get_edges_batch - 批量获取多个边的属性
575
  print("== 测试 get_edges_batch")
 
578
  edges_dict = await storage.get_edges_batch(edge_dicts)
579
  print(f"批量获取边属性结果: {edges_dict.keys()}")
580
  assert len(edges_dict) == 3, f"应返回3条边的属性,实际返回 {len(edges_dict)} 条"
581
+ assert (
582
+ node1_id,
583
+ node2_id,
584
+ ) in edges_dict, f"边 {node1_id} -> {node2_id} 应在返回结果中"
585
+ assert (
586
+ node2_id,
587
+ node3_id,
588
+ ) in edges_dict, f"边 {node2_id} -> {node3_id} 应在返回结果中"
589
+ assert (
590
+ node3_id,
591
+ node4_id,
592
+ ) in edges_dict, f"边 {node3_id} -> {node4_id} 应在返回结果中"
593
+ assert (
594
+ edges_dict[(node1_id, node2_id)]["relationship"]
595
+ == edge1_data["relationship"]
596
+ ), f"边 {node1_id} -> {node2_id} 关系不匹配"
597
+ assert (
598
+ edges_dict[(node2_id, node3_id)]["relationship"]
599
+ == edge2_data["relationship"]
600
+ ), f"边 {node2_id} -> {node3_id} 关系不匹配"
601
+ assert (
602
+ edges_dict[(node3_id, node4_id)]["relationship"]
603
+ == edge5_data["relationship"]
604
+ ), f"边 {node3_id} -> {node4_id} 关系不匹配"
605
 
606
  # 6. 测试 get_nodes_edges_batch - 批量获取多个节点的所有边
607
  print("== 测试 get_nodes_edges_batch")
608
  nodes_edges = await storage.get_nodes_edges_batch([node1_id, node3_id])
609
  print(f"批量获取节点边结果: {nodes_edges.keys()}")
610
+ assert (
611
+ len(nodes_edges) == 2
612
+ ), f"应返回2个节点的边,实际返回 {len(nodes_edges)} 个"
613
  assert node1_id in nodes_edges, f"{node1_id} 应在返回结果中"
614
  assert node3_id in nodes_edges, f"{node3_id} 应在返回结果中"
615
+ assert (
616
+ len(nodes_edges[node1_id]) == 3
617
+ ), f"{node1_id} 应有3条边,实际有 {len(nodes_edges[node1_id])} 条"
618
+ assert (
619
+ len(nodes_edges[node3_id]) == 3
620
+ ), f"{node3_id} 应有3条边,实际有 {len(nodes_edges[node3_id])} 条"
621
 
622
  # 7. 清理数据
623
  print("== 测试 drop")
624
  result = await storage.drop()
625
  print(f"清理结果: {result}")
626
+ assert (
627
+ result["status"] == "success"
628
+ ), f"清理应成功,实际状态为 {result['status']}"
629
 
630
  print("\n批量操作测试完成")
631
  return True
 
687
  if basic_result:
688
  ASCIIColors.cyan("\n=== 开始高级测试 ===")
689
  advanced_result = await test_graph_advanced(storage)
690
+
691
  if advanced_result:
692
  ASCIIColors.cyan("\n=== 开始批量操作测试 ===")
693
  await test_graph_batch_operations(storage)