yangdx committed on
Commit
ec8fba9
·
1 Parent(s): 073182d

Implement batch query functions for PGGraphStorage (PostgreSQL AGE graph storage)

Browse files
Files changed (1) hide show
  1. lightrag/kg/postgres_impl.py +191 -0
lightrag/kg/postgres_impl.py CHANGED
@@ -1458,6 +1458,197 @@ class PGGraphStorage(BaseGraphStorage):
1458
  logger.error(f"Error during edge deletion: {str(e)}")
1459
  raise
1460
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1461
  async def get_all_labels(self) -> list[str]:
1462
  """
1463
  Get all labels (node IDs) in the graph.
 
1458
  logger.error(f"Error during edge deletion: {str(e)}")
1459
  raise
1460
 
1461
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
    """
    Retrieve multiple nodes in one query using UNWIND.

    Double quotes are removed from every ID before it is matched against the
    ``entity_id`` property; the returned dictionary is nevertheless keyed by
    the IDs exactly as the caller supplied them.

    Args:
        node_ids: List of node entity IDs to fetch.

    Returns:
        A dictionary mapping each requested node_id that was found to its
        property dictionary. IDs that match no node are omitted, so callers
        should use ``.get()`` when a node may be missing.
    """
    if not node_ids:
        return {}

    # Map the quote-stripped ID that is actually matched in the graph back to
    # the ID(s) the caller asked for, so results can be keyed by requested ID.
    # Using a dict also de-duplicates the IDs sent to the database.
    requested_by_lookup: dict[str, list[str]] = {}
    for node_id in node_ids:
        lookup_id = node_id.replace('"', "")
        requesters = requested_by_lookup.setdefault(lookup_id, [])
        if node_id not in requesters:
            requesters.append(node_id)

    # Backslashes must be doubled: inside a double-quoted Cypher literal a lone
    # backslash starts an escape sequence and a trailing one would swallow the
    # closing quote.
    # NOTE(review): an ID containing "$$" would still terminate the
    # dollar-quoted Cypher body; that case is not handled here.
    formatted_ids = ", ".join(
        '"' + lookup_id.replace("\\", "\\\\") + '"'
        for lookup_id in requested_by_lookup
    )

    query = """SELECT * FROM cypher('%s', $$
                 UNWIND [%s] AS node_id
                 MATCH (n:base {entity_id: node_id})
                 RETURN node_id, n
                 $$) AS (node_id text, n agtype)""" % (
        self.graph_name,
        formatted_ids,
    )

    results = await self._query(query)

    # Build result dictionary
    nodes_dict = {}
    for result in results:
        if result["node_id"] and result["n"]:
            node_dict = result["n"]["properties"]
            # Remove the 'base' label if present in a 'labels' property
            if "labels" in node_dict:
                node_dict["labels"] = [
                    label for label in node_dict["labels"] if label != "base"
                ]
            # Fall back to the returned ID if it cannot be mapped back.
            for requested_id in requested_by_lookup.get(
                result["node_id"], [result["node_id"]]
            ):
                nodes_dict[requested_id] = node_dict

    return nodes_dict
1499
+
1500
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
    """
    Retrieve the degree for multiple nodes in a single query using UNWIND.

    Double quotes are removed from every ID before it is matched against the
    ``entity_id`` property; the returned dictionary is keyed by the IDs
    exactly as the caller supplied them.

    Args:
        node_ids: List of node labels (entity_id values) to look up.

    Returns:
        A dictionary mapping each node_id to its degree (number of relationships).
        If a node is not found, its degree will be set to 0.
    """
    if not node_ids:
        return {}

    # Map the quote-stripped ID that is actually matched in the graph back to
    # the ID(s) the caller asked for. Without this, an ID containing a double
    # quote would have its real degree stored under the stripped key while the
    # requested key was back-filled with 0.
    requested_by_lookup: dict[str, list[str]] = {}
    for node_id in node_ids:
        lookup_id = node_id.replace('"', "")
        requesters = requested_by_lookup.setdefault(lookup_id, [])
        if node_id not in requesters:
            requesters.append(node_id)

    # Backslashes must be doubled so they survive as literal characters inside
    # the double-quoted Cypher string.
    # NOTE(review): an ID containing "$$" would still terminate the
    # dollar-quoted Cypher body; that case is not handled here.
    formatted_ids = ", ".join(
        '"' + lookup_id.replace("\\", "\\\\") + '"'
        for lookup_id in requested_by_lookup
    )

    query = """SELECT * FROM cypher('%s', $$
                 UNWIND [%s] AS node_id
                 MATCH (n:base {entity_id: node_id})
                 OPTIONAL MATCH (n)-[r]-()
                 RETURN node_id, count(r) AS degree
                 $$) AS (node_id text, degree bigint)""" % (
        self.graph_name,
        formatted_ids,
    )

    results = await self._query(query)

    # Build result dictionary
    degrees_dict = {}
    for result in results:
        returned_id = result["node_id"]
        if returned_id is not None:
            degree = int(result["degree"])
            # Fall back to the returned ID if it cannot be mapped back.
            for requested_id in requested_by_lookup.get(returned_id, [returned_id]):
                degrees_dict[requested_id] = degree

    # The MATCH drops IDs with no node, so back-fill every requested ID with 0.
    for node_id in node_ids:
        degrees_dict.setdefault(node_id, 0)

    return degrees_dict
1541
+
1542
async def edge_degrees_batch(
    self, edges: list[tuple[str, str]]
) -> dict[tuple[str, str], int]:
    """
    Compute the combined degree of every edge in one batch.

    The combined degree of an edge is the degree of its source node plus the
    degree of its target node. All node degrees are fetched with a single
    call to ``node_degrees_batch``.

    Args:
        edges: List of (source_node_id, target_node_id) tuples

    Returns:
        Dictionary mapping each (source, target) tuple to the sum of the two
        node degrees; a node missing from the lookup counts as 0.
    """
    if not edges:
        return {}

    # Collect each distinct endpoint once so the degree lookup is not repeated.
    endpoints = {endpoint for src, tgt in edges for endpoint in (src, tgt)}

    degree_of = await self.node_degrees_batch(list(endpoints))

    return {
        (src, tgt): degree_of.get(src, 0) + degree_of.get(tgt, 0)
        for src, tgt in edges
    }
1572
+
1573
async def get_edges_batch(
    self, pairs: list[dict[str, str]]
) -> dict[tuple[str, str], dict]:
    """
    Retrieve edge properties for multiple (src, tgt) pairs in one query.

    Double quotes are removed from every ID before it is matched against the
    ``entity_id`` property; the returned dictionary is keyed by the
    (src, tgt) values exactly as the caller supplied them. The edge is
    matched without regard to direction.

    Args:
        pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]

    Returns:
        A dictionary mapping (src, tgt) tuples to their edge properties.
        Pairs with no matching edge (or with empty properties) are omitted.
    """
    if not pairs:
        return {}

    # Extract source/target IDs from the list of dictionaries. Map the
    # quote-stripped pair that is actually matched in the graph back to the
    # pair(s) the caller asked for; this also de-duplicates the lookups.
    requested_by_lookup: dict[tuple[str, str], list[tuple[str, str]]] = {}
    for pair in pairs:
        requested = (pair["src"], pair["tgt"])
        lookup = (requested[0].replace('"', ""), requested[1].replace('"', ""))
        requesters = requested_by_lookup.setdefault(lookup, [])
        if requested not in requesters:
            requesters.append(requested)

    # Build two parallel arrays; the query pairs them up by array index.
    # Backslashes are doubled so they stay literal inside the Cypher strings.
    # NOTE(review): an ID containing "$$" would still terminate the
    # dollar-quoted Cypher body; that case is not handled here.
    src_array = ", ".join(
        '"' + src.replace("\\", "\\\\") + '"' for src, _ in requested_by_lookup
    )
    tgt_array = ", ".join(
        '"' + tgt.replace("\\", "\\\\") + '"' for _, tgt in requested_by_lookup
    )

    query = f"""SELECT * FROM cypher('{self.graph_name}', $$
                 WITH [{src_array}] AS sources, [{tgt_array}] AS targets
                 UNWIND range(0, size(sources)-1) AS i
                 MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]-(b:base {{entity_id: targets[i]}})
                 RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
                 $$) AS (source text, target text, edge_properties agtype)"""

    results = await self._query(query)

    # Build the result dictionary
    edges_dict = {}
    for result in results:
        if result["source"] and result["target"] and result["edge_properties"]:
            returned = (result["source"], result["target"])
            # Fall back to the returned pair if it cannot be mapped back.
            for requested in requested_by_lookup.get(returned, [returned]):
                edges_dict[requested] = result["edge_properties"]

    return edges_dict
1613
+
1614
async def get_nodes_edges_batch(
    self, node_ids: list[str]
) -> dict[str, list[tuple[str, str]]]:
    """
    Get all edges for multiple nodes in a single batch operation.

    Double quotes are removed from every ID before it is matched against the
    ``entity_id`` property; the returned dictionary is keyed by the IDs
    exactly as the caller supplied them.

    Args:
        node_ids: List of node IDs to get edges for

    Returns:
        Dictionary mapping node IDs to lists of (source, target) edge tuples.
        Edges are matched without regard to direction, and the queried node's
        entity_id is always the first element of the tuple. Every requested
        ID is present; nodes with no edges map to an empty list.
    """
    if not node_ids:
        return {}

    # Map the quote-stripped ID that is actually matched in the graph back to
    # the ID(s) the caller asked for. Rows come back carrying the stripped ID,
    # so indexing the result directly by it would raise KeyError whenever a
    # requested ID contained a double quote.
    requested_by_lookup: dict[str, list[str]] = {}
    for node_id in node_ids:
        lookup_id = node_id.replace('"', "")
        requesters = requested_by_lookup.setdefault(lookup_id, [])
        if node_id not in requesters:
            requesters.append(node_id)

    # Backslashes must be doubled so they survive as literal characters inside
    # the double-quoted Cypher string.
    # NOTE(review): an ID containing "$$" would still terminate the
    # dollar-quoted Cypher body; that case is not handled here.
    formatted_ids = ", ".join(
        '"' + lookup_id.replace("\\", "\\\\") + '"'
        for lookup_id in requested_by_lookup
    )

    query = """SELECT * FROM cypher('%s', $$
                 UNWIND [%s] AS node_id
                 MATCH (n:base {entity_id: node_id})
                 OPTIONAL MATCH (n)-[]-(connected:base)
                 RETURN node_id, connected.entity_id AS connected_id
                 $$) AS (node_id text, connected_id text)""" % (
        self.graph_name,
        formatted_ids,
    )

    results = await self._query(query)

    # Build result dictionary
    nodes_edges_dict = {node_id: [] for node_id in node_ids}
    for result in results:
        returned_id = result["node_id"]
        if returned_id and result["connected_id"]:
            edge = (returned_id, result["connected_id"])
            # Fall back to the returned ID if it cannot be mapped back.
            for requested_id in requested_by_lookup.get(returned_id, [returned_id]):
                nodes_edges_dict.setdefault(requested_id, []).append(edge)

    return nodes_edges_dict
1651
+
1652
  async def get_all_labels(self) -> list[str]:
1653
  """
1654
  Get all labels (node IDs) in the graph.