yangdx committed on
Commit
32943a3
·
1 Parent(s): 27dc94d

Optimize PostgreSQL AGE graph storage performance by separating forward and backward edge queries

Browse files
Files changed (1) hide show
  1. lightrag/kg/postgres_impl.py +75 -33
lightrag/kg/postgres_impl.py CHANGED
@@ -1170,9 +1170,6 @@ class PGGraphStorage(BaseGraphStorage):
1170
  Returns:
1171
  list[dict[str, Any]]: a list of dictionaries containing the result set
1172
  """
1173
-
1174
- logger.info(f"Executing graph query: {query}")
1175
-
1176
  try:
1177
  if readonly:
1178
  data = await self.db.query(
@@ -1255,8 +1252,8 @@ class PGGraphStorage(BaseGraphStorage):
1255
  label = node_id.strip('"')
1256
 
1257
  query = """SELECT * FROM cypher('%s', $$
1258
- MATCH (n:base {entity_id: "%s"})-[]-(x)
1259
- RETURN count(x) AS total_edge_count
1260
  $$) AS (total_edge_count integer)""" % (self.graph_name, label)
1261
  record = (await self._query(query))[0]
1262
  if record:
@@ -1523,12 +1520,14 @@ class PGGraphStorage(BaseGraphStorage):
1523
  async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
1524
  """
1525
  Retrieve the degree for multiple nodes in a single query using UNWIND.
 
 
1526
 
1527
  Args:
1528
  node_ids: List of node labels (entity_id values) to look up.
1529
 
1530
  Returns:
1531
- A dictionary mapping each node_id to its degree (number of relationships).
1532
  If a node is not found, its degree will be set to 0.
1533
  """
1534
  if not node_ids:
@@ -1539,28 +1538,45 @@ class PGGraphStorage(BaseGraphStorage):
1539
  ['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
1540
  )
1541
 
1542
- query = """SELECT * FROM cypher('%s', $$
1543
  UNWIND [%s] AS node_id
1544
  MATCH (n:base {entity_id: node_id})
1545
- OPTIONAL MATCH (n)-[r]->()
1546
- RETURN node_id, count(r) AS degree
1547
- $$) AS (node_id text, degree bigint)""" % (
1548
  self.graph_name,
1549
  formatted_ids,
1550
  )
1551
 
1552
- results = await self._query(query)
 
 
 
 
 
 
 
 
1553
 
1554
- # Build result dictionary
1555
- degrees_dict = {}
1556
- for result in results:
1557
- if result["node_id"] is not None:
1558
- degrees_dict[result["node_id"]] = int(result["degree"])
1559
 
1560
- # Ensure all requested node_ids are in the result dictionary
 
 
 
 
 
 
 
 
 
 
 
1561
  for node_id in node_ids:
1562
- if node_id not in degrees_dict:
1563
- degrees_dict[node_id] = 0
 
1564
 
1565
  return degrees_dict
1566
 
@@ -1602,6 +1618,7 @@ class PGGraphStorage(BaseGraphStorage):
1602
  ) -> dict[tuple[str, str], dict]:
1603
  """
1604
  Retrieve edge properties for multiple (src, tgt) pairs in one query.
 
1605
 
1606
  Args:
1607
  pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
@@ -1612,33 +1629,41 @@ class PGGraphStorage(BaseGraphStorage):
1612
  if not pairs:
1613
  return {}
1614
 
1615
- # 从字典列表中提取源节点和目标节点ID
1616
  src_nodes = []
1617
  tgt_nodes = []
1618
  for pair in pairs:
1619
  src_nodes.append(pair["src"].replace('"', ""))
1620
  tgt_nodes.append(pair["tgt"].replace('"', ""))
1621
 
1622
- # 构建查询,使用数组索引来匹配源节点和目标节点
1623
  src_array = ", ".join([f'"{src}"' for src in src_nodes])
1624
  tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])
1625
 
1626
- query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1627
  WITH [{src_array}] AS sources, [{tgt_array}] AS targets
1628
  UNWIND range(0, size(sources)-1) AS i
1629
  MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]->(b:base {{entity_id: targets[i]}})
1630
  RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
1631
  $$) AS (source text, target text, edge_properties agtype)"""
1632
 
1633
- results = await self._query(query)
 
 
 
 
 
 
 
 
1634
 
1635
- # 构建结果字典
1636
  edges_dict = {}
1637
- for result in results:
 
 
 
 
 
1638
  if result["source"] and result["target"] and result["edge_properties"]:
1639
- edges_dict[(result["source"], result["target"])] = result[
1640
- "edge_properties"
1641
- ]
1642
 
1643
  return edges_dict
1644
 
@@ -1646,7 +1671,7 @@ class PGGraphStorage(BaseGraphStorage):
1646
  self, node_ids: list[str]
1647
  ) -> dict[str, list[tuple[str, str]]]:
1648
  """
1649
- Get all edges for multiple nodes in a single batch operation.
1650
 
1651
  Args:
1652
  node_ids: List of node IDs to get edges for
@@ -1662,7 +1687,7 @@ class PGGraphStorage(BaseGraphStorage):
1662
  ['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
1663
  )
1664
 
1665
- query = """SELECT * FROM cypher('%s', $$
1666
  UNWIND [%s] AS node_id
1667
  MATCH (n:base {entity_id: node_id})
1668
  OPTIONAL MATCH (n:base)-[]->(connected:base)
@@ -1672,15 +1697,32 @@ class PGGraphStorage(BaseGraphStorage):
1672
  formatted_ids,
1673
  )
1674
 
1675
- results = await self._query(query)
 
 
 
 
 
 
 
 
 
 
 
1676
 
1677
- # Build result dictionary
1678
  nodes_edges_dict = {node_id: [] for node_id in node_ids}
1679
- for result in results:
 
1680
  if result["node_id"] and result["connected_id"]:
1681
  nodes_edges_dict[result["node_id"]].append(
1682
  (result["node_id"], result["connected_id"])
1683
  )
 
 
 
 
 
 
1684
 
1685
  return nodes_edges_dict
1686
 
 
1170
  Returns:
1171
  list[dict[str, Any]]: a list of dictionaries containing the result set
1172
  """
 
 
 
1173
  try:
1174
  if readonly:
1175
  data = await self.db.query(
 
1252
  label = node_id.strip('"')
1253
 
1254
  query = """SELECT * FROM cypher('%s', $$
1255
+ MATCH (n:base {entity_id: "%s"})-[r]-()
1256
+ RETURN count(r) AS total_edge_count
1257
  $$) AS (total_edge_count integer)""" % (self.graph_name, label)
1258
  record = (await self._query(query))[0]
1259
  if record:
 
1520
  async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
1521
  """
1522
  Retrieve the degree for multiple nodes in a single query using UNWIND.
1523
+ Calculates the total degree by counting distinct relationships.
1524
+ Uses separate queries for outgoing and incoming edges.
1525
 
1526
  Args:
1527
  node_ids: List of node labels (entity_id values) to look up.
1528
 
1529
  Returns:
1530
+ A dictionary mapping each node_id to its degree (total number of relationships).
1531
  If a node is not found, its degree will be set to 0.
1532
  """
1533
  if not node_ids:
 
1538
  ['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
1539
  )
1540
 
1541
+ outgoing_query = """SELECT * FROM cypher('%s', $$
1542
  UNWIND [%s] AS node_id
1543
  MATCH (n:base {entity_id: node_id})
1544
+ OPTIONAL MATCH (n)-[r]->(a)
1545
+ RETURN node_id, count(a) AS out_degree
1546
+ $$) AS (node_id text, out_degree bigint)""" % (
1547
  self.graph_name,
1548
  formatted_ids,
1549
  )
1550
 
1551
+ incoming_query = """SELECT * FROM cypher('%s', $$
1552
+ UNWIND [%s] AS node_id
1553
+ MATCH (n:base {entity_id: node_id})
1554
+ OPTIONAL MATCH (n)<-[r]-(b)
1555
+ RETURN node_id, count(b) AS in_degree
1556
+ $$) AS (node_id text, in_degree bigint)""" % (
1557
+ self.graph_name,
1558
+ formatted_ids,
1559
+ )
1560
 
1561
+ outgoing_results = await self._query(outgoing_query)
1562
+ incoming_results = await self._query(incoming_query)
 
 
 
1563
 
1564
+ out_degrees = {}
1565
+ in_degrees = {}
1566
+
1567
+ for result in outgoing_results:
1568
+ if result["node_id"] is not None:
1569
+ out_degrees[result["node_id"]] = int(result["out_degree"])
1570
+
1571
+ for result in incoming_results:
1572
+ if result["node_id"] is not None:
1573
+ in_degrees[result["node_id"]] = int(result["in_degree"])
1574
+
1575
+ degrees_dict = {}
1576
  for node_id in node_ids:
1577
+ out_degree = out_degrees.get(node_id, 0)
1578
+ in_degree = in_degrees.get(node_id, 0)
1579
+ degrees_dict[node_id] = out_degree + in_degree
1580
 
1581
  return degrees_dict
1582
 
 
1618
  ) -> dict[tuple[str, str], dict]:
1619
  """
1620
  Retrieve edge properties for multiple (src, tgt) pairs in one query.
1621
+ Get forward and backward edges separately and merge them before returning.
1622
 
1623
  Args:
1624
  pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
 
1629
  if not pairs:
1630
  return {}
1631
 
 
1632
  src_nodes = []
1633
  tgt_nodes = []
1634
  for pair in pairs:
1635
  src_nodes.append(pair["src"].replace('"', ""))
1636
  tgt_nodes.append(pair["tgt"].replace('"', ""))
1637
 
 
1638
  src_array = ", ".join([f'"{src}"' for src in src_nodes])
1639
  tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])
1640
 
1641
+ forward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1642
  WITH [{src_array}] AS sources, [{tgt_array}] AS targets
1643
  UNWIND range(0, size(sources)-1) AS i
1644
  MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]->(b:base {{entity_id: targets[i]}})
1645
  RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
1646
  $$) AS (source text, target text, edge_properties agtype)"""
1647
 
1648
+ backward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1649
+ WITH [{src_array}] AS sources, [{tgt_array}] AS targets
1650
+ UNWIND range(0, size(sources)-1) AS i
1651
+ MATCH (a:base {{entity_id: sources[i]}})<-[r:DIRECTED]-(b:base {{entity_id: targets[i]}})
1652
+ RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
1653
+ $$) AS (source text, target text, edge_properties agtype)"""
1654
+
1655
+ forward_results = await self._query(forward_query)
1656
+ backward_results = await self._query(backward_query)
1657
 
 
1658
  edges_dict = {}
1659
+
1660
+ for result in forward_results:
1661
+ if result["source"] and result["target"] and result["edge_properties"]:
1662
+ edges_dict[(result["source"], result["target"])] = result["edge_properties"]
1663
+
1664
+ for result in backward_results:
1665
  if result["source"] and result["target"] and result["edge_properties"]:
1666
+ edges_dict[(result["source"], result["target"])] = result["edge_properties"]
 
 
1667
 
1668
  return edges_dict
1669
 
 
1671
  self, node_ids: list[str]
1672
  ) -> dict[str, list[tuple[str, str]]]:
1673
  """
1674
+ Get all edges (both outgoing and incoming) for multiple nodes in a single batch operation.
1675
 
1676
  Args:
1677
  node_ids: List of node IDs to get edges for
 
1687
  ['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
1688
  )
1689
 
1690
+ outgoing_query = """SELECT * FROM cypher('%s', $$
1691
  UNWIND [%s] AS node_id
1692
  MATCH (n:base {entity_id: node_id})
1693
  OPTIONAL MATCH (n:base)-[]->(connected:base)
 
1697
  formatted_ids,
1698
  )
1699
 
1700
+ incoming_query = """SELECT * FROM cypher('%s', $$
1701
+ UNWIND [%s] AS node_id
1702
+ MATCH (n:base {entity_id: node_id})
1703
+ OPTIONAL MATCH (n:base)<-[]-(connected:base)
1704
+ RETURN node_id, connected.entity_id AS connected_id
1705
+ $$) AS (node_id text, connected_id text)""" % (
1706
+ self.graph_name,
1707
+ formatted_ids,
1708
+ )
1709
+
1710
+ outgoing_results = await self._query(outgoing_query)
1711
+ incoming_results = await self._query(incoming_query)
1712
 
 
1713
  nodes_edges_dict = {node_id: [] for node_id in node_ids}
1714
+
1715
+ for result in outgoing_results:
1716
  if result["node_id"] and result["connected_id"]:
1717
  nodes_edges_dict[result["node_id"]].append(
1718
  (result["node_id"], result["connected_id"])
1719
  )
1720
+
1721
+ for result in incoming_results:
1722
+ if result["node_id"] and result["connected_id"]:
1723
+ nodes_edges_dict[result["node_id"]].append(
1724
+ (result["connected_id"], result["node_id"])
1725
+ )
1726
 
1727
  return nodes_edges_dict
1728