yangdx
commited on
Commit
·
32943a3
1
Parent(s):
27dc94d
Optimize PostgreSQL AGE graph storage performance by eperate forward and backward edge query
Browse files- lightrag/kg/postgres_impl.py +75 -33
lightrag/kg/postgres_impl.py
CHANGED
@@ -1170,9 +1170,6 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1170 |
Returns:
|
1171 |
list[dict[str, Any]]: a list of dictionaries containing the result set
|
1172 |
"""
|
1173 |
-
|
1174 |
-
logger.info(f"Executing graph query: {query}")
|
1175 |
-
|
1176 |
try:
|
1177 |
if readonly:
|
1178 |
data = await self.db.query(
|
@@ -1255,8 +1252,8 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1255 |
label = node_id.strip('"')
|
1256 |
|
1257 |
query = """SELECT * FROM cypher('%s', $$
|
1258 |
-
MATCH (n:base {entity_id: "%s"})-[]-(
|
1259 |
-
RETURN count(
|
1260 |
$$) AS (total_edge_count integer)""" % (self.graph_name, label)
|
1261 |
record = (await self._query(query))[0]
|
1262 |
if record:
|
@@ -1523,12 +1520,14 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1523 |
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
1524 |
"""
|
1525 |
Retrieve the degree for multiple nodes in a single query using UNWIND.
|
|
|
|
|
1526 |
|
1527 |
Args:
|
1528 |
node_ids: List of node labels (entity_id values) to look up.
|
1529 |
|
1530 |
Returns:
|
1531 |
-
A dictionary mapping each node_id to its degree (number of relationships).
|
1532 |
If a node is not found, its degree will be set to 0.
|
1533 |
"""
|
1534 |
if not node_ids:
|
@@ -1539,28 +1538,45 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1539 |
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
|
1540 |
)
|
1541 |
|
1542 |
-
|
1543 |
UNWIND [%s] AS node_id
|
1544 |
MATCH (n:base {entity_id: node_id})
|
1545 |
-
OPTIONAL MATCH (n)-[r]->()
|
1546 |
-
RETURN node_id, count(
|
1547 |
-
$$) AS (node_id text,
|
1548 |
self.graph_name,
|
1549 |
formatted_ids,
|
1550 |
)
|
1551 |
|
1552 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1553 |
|
1554 |
-
|
1555 |
-
|
1556 |
-
for result in results:
|
1557 |
-
if result["node_id"] is not None:
|
1558 |
-
degrees_dict[result["node_id"]] = int(result["degree"])
|
1559 |
|
1560 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1561 |
for node_id in node_ids:
|
1562 |
-
|
1563 |
-
|
|
|
1564 |
|
1565 |
return degrees_dict
|
1566 |
|
@@ -1602,6 +1618,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1602 |
) -> dict[tuple[str, str], dict]:
|
1603 |
"""
|
1604 |
Retrieve edge properties for multiple (src, tgt) pairs in one query.
|
|
|
1605 |
|
1606 |
Args:
|
1607 |
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
|
@@ -1612,33 +1629,41 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1612 |
if not pairs:
|
1613 |
return {}
|
1614 |
|
1615 |
-
# 从字典列表中提取源节点和目标节点ID
|
1616 |
src_nodes = []
|
1617 |
tgt_nodes = []
|
1618 |
for pair in pairs:
|
1619 |
src_nodes.append(pair["src"].replace('"', ""))
|
1620 |
tgt_nodes.append(pair["tgt"].replace('"', ""))
|
1621 |
|
1622 |
-
# 构建查询,使用数组索引来匹配源节点和目标节点
|
1623 |
src_array = ", ".join([f'"{src}"' for src in src_nodes])
|
1624 |
tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])
|
1625 |
|
1626 |
-
|
1627 |
WITH [{src_array}] AS sources, [{tgt_array}] AS targets
|
1628 |
UNWIND range(0, size(sources)-1) AS i
|
1629 |
MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]->(b:base {{entity_id: targets[i]}})
|
1630 |
RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
|
1631 |
$$) AS (source text, target text, edge_properties agtype)"""
|
1632 |
|
1633 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1634 |
|
1635 |
-
# 构建结果字典
|
1636 |
edges_dict = {}
|
1637 |
-
|
|
|
|
|
|
|
|
|
|
|
1638 |
if result["source"] and result["target"] and result["edge_properties"]:
|
1639 |
-
edges_dict[(result["source"], result["target"])] = result[
|
1640 |
-
"edge_properties"
|
1641 |
-
]
|
1642 |
|
1643 |
return edges_dict
|
1644 |
|
@@ -1646,7 +1671,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1646 |
self, node_ids: list[str]
|
1647 |
) -> dict[str, list[tuple[str, str]]]:
|
1648 |
"""
|
1649 |
-
Get all edges for multiple nodes in a single batch operation.
|
1650 |
|
1651 |
Args:
|
1652 |
node_ids: List of node IDs to get edges for
|
@@ -1662,7 +1687,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1662 |
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
|
1663 |
)
|
1664 |
|
1665 |
-
|
1666 |
UNWIND [%s] AS node_id
|
1667 |
MATCH (n:base {entity_id: node_id})
|
1668 |
OPTIONAL MATCH (n:base)-[]->(connected:base)
|
@@ -1672,15 +1697,32 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1672 |
formatted_ids,
|
1673 |
)
|
1674 |
|
1675 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1676 |
|
1677 |
-
# Build result dictionary
|
1678 |
nodes_edges_dict = {node_id: [] for node_id in node_ids}
|
1679 |
-
|
|
|
1680 |
if result["node_id"] and result["connected_id"]:
|
1681 |
nodes_edges_dict[result["node_id"]].append(
|
1682 |
(result["node_id"], result["connected_id"])
|
1683 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
1684 |
|
1685 |
return nodes_edges_dict
|
1686 |
|
|
|
1170 |
Returns:
|
1171 |
list[dict[str, Any]]: a list of dictionaries containing the result set
|
1172 |
"""
|
|
|
|
|
|
|
1173 |
try:
|
1174 |
if readonly:
|
1175 |
data = await self.db.query(
|
|
|
1252 |
label = node_id.strip('"')
|
1253 |
|
1254 |
query = """SELECT * FROM cypher('%s', $$
|
1255 |
+
MATCH (n:base {entity_id: "%s"})-[r]-()
|
1256 |
+
RETURN count(r) AS total_edge_count
|
1257 |
$$) AS (total_edge_count integer)""" % (self.graph_name, label)
|
1258 |
record = (await self._query(query))[0]
|
1259 |
if record:
|
|
|
1520 |
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
1521 |
"""
|
1522 |
Retrieve the degree for multiple nodes in a single query using UNWIND.
|
1523 |
+
Calculates the total degree by counting distinct relationships.
|
1524 |
+
Uses separate queries for outgoing and incoming edges.
|
1525 |
|
1526 |
Args:
|
1527 |
node_ids: List of node labels (entity_id values) to look up.
|
1528 |
|
1529 |
Returns:
|
1530 |
+
A dictionary mapping each node_id to its degree (total number of relationships).
|
1531 |
If a node is not found, its degree will be set to 0.
|
1532 |
"""
|
1533 |
if not node_ids:
|
|
|
1538 |
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
|
1539 |
)
|
1540 |
|
1541 |
+
outgoing_query = """SELECT * FROM cypher('%s', $$
|
1542 |
UNWIND [%s] AS node_id
|
1543 |
MATCH (n:base {entity_id: node_id})
|
1544 |
+
OPTIONAL MATCH (n)-[r]->(a)
|
1545 |
+
RETURN node_id, count(a) AS out_degree
|
1546 |
+
$$) AS (node_id text, out_degree bigint)""" % (
|
1547 |
self.graph_name,
|
1548 |
formatted_ids,
|
1549 |
)
|
1550 |
|
1551 |
+
incoming_query = """SELECT * FROM cypher('%s', $$
|
1552 |
+
UNWIND [%s] AS node_id
|
1553 |
+
MATCH (n:base {entity_id: node_id})
|
1554 |
+
OPTIONAL MATCH (n)<-[r]-(b)
|
1555 |
+
RETURN node_id, count(b) AS in_degree
|
1556 |
+
$$) AS (node_id text, in_degree bigint)""" % (
|
1557 |
+
self.graph_name,
|
1558 |
+
formatted_ids,
|
1559 |
+
)
|
1560 |
|
1561 |
+
outgoing_results = await self._query(outgoing_query)
|
1562 |
+
incoming_results = await self._query(incoming_query)
|
|
|
|
|
|
|
1563 |
|
1564 |
+
out_degrees = {}
|
1565 |
+
in_degrees = {}
|
1566 |
+
|
1567 |
+
for result in outgoing_results:
|
1568 |
+
if result["node_id"] is not None:
|
1569 |
+
out_degrees[result["node_id"]] = int(result["out_degree"])
|
1570 |
+
|
1571 |
+
for result in incoming_results:
|
1572 |
+
if result["node_id"] is not None:
|
1573 |
+
in_degrees[result["node_id"]] = int(result["in_degree"])
|
1574 |
+
|
1575 |
+
degrees_dict = {}
|
1576 |
for node_id in node_ids:
|
1577 |
+
out_degree = out_degrees.get(node_id, 0)
|
1578 |
+
in_degree = in_degrees.get(node_id, 0)
|
1579 |
+
degrees_dict[node_id] = out_degree + in_degree
|
1580 |
|
1581 |
return degrees_dict
|
1582 |
|
|
|
1618 |
) -> dict[tuple[str, str], dict]:
|
1619 |
"""
|
1620 |
Retrieve edge properties for multiple (src, tgt) pairs in one query.
|
1621 |
+
Get forward and backward edges seperately and merge them before return
|
1622 |
|
1623 |
Args:
|
1624 |
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
|
|
|
1629 |
if not pairs:
|
1630 |
return {}
|
1631 |
|
|
|
1632 |
src_nodes = []
|
1633 |
tgt_nodes = []
|
1634 |
for pair in pairs:
|
1635 |
src_nodes.append(pair["src"].replace('"', ""))
|
1636 |
tgt_nodes.append(pair["tgt"].replace('"', ""))
|
1637 |
|
|
|
1638 |
src_array = ", ".join([f'"{src}"' for src in src_nodes])
|
1639 |
tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])
|
1640 |
|
1641 |
+
forward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
1642 |
WITH [{src_array}] AS sources, [{tgt_array}] AS targets
|
1643 |
UNWIND range(0, size(sources)-1) AS i
|
1644 |
MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]->(b:base {{entity_id: targets[i]}})
|
1645 |
RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
|
1646 |
$$) AS (source text, target text, edge_properties agtype)"""
|
1647 |
|
1648 |
+
backward_query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
1649 |
+
WITH [{src_array}] AS sources, [{tgt_array}] AS targets
|
1650 |
+
UNWIND range(0, size(sources)-1) AS i
|
1651 |
+
MATCH (a:base {{entity_id: sources[i]}})<-[r:DIRECTED]-(b:base {{entity_id: targets[i]}})
|
1652 |
+
RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
|
1653 |
+
$$) AS (source text, target text, edge_properties agtype)"""
|
1654 |
+
|
1655 |
+
forward_results = await self._query(forward_query)
|
1656 |
+
backward_results = await self._query(backward_query)
|
1657 |
|
|
|
1658 |
edges_dict = {}
|
1659 |
+
|
1660 |
+
for result in forward_results:
|
1661 |
+
if result["source"] and result["target"] and result["edge_properties"]:
|
1662 |
+
edges_dict[(result["source"], result["target"])] = result["edge_properties"]
|
1663 |
+
|
1664 |
+
for result in backward_results:
|
1665 |
if result["source"] and result["target"] and result["edge_properties"]:
|
1666 |
+
edges_dict[(result["source"], result["target"])] = result["edge_properties"]
|
|
|
|
|
1667 |
|
1668 |
return edges_dict
|
1669 |
|
|
|
1671 |
self, node_ids: list[str]
|
1672 |
) -> dict[str, list[tuple[str, str]]]:
|
1673 |
"""
|
1674 |
+
Get all edges (both outgoing and incoming) for multiple nodes in a single batch operation.
|
1675 |
|
1676 |
Args:
|
1677 |
node_ids: List of node IDs to get edges for
|
|
|
1687 |
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
|
1688 |
)
|
1689 |
|
1690 |
+
outgoing_query = """SELECT * FROM cypher('%s', $$
|
1691 |
UNWIND [%s] AS node_id
|
1692 |
MATCH (n:base {entity_id: node_id})
|
1693 |
OPTIONAL MATCH (n:base)-[]->(connected:base)
|
|
|
1697 |
formatted_ids,
|
1698 |
)
|
1699 |
|
1700 |
+
incoming_query = """SELECT * FROM cypher('%s', $$
|
1701 |
+
UNWIND [%s] AS node_id
|
1702 |
+
MATCH (n:base {entity_id: node_id})
|
1703 |
+
OPTIONAL MATCH (n:base)<-[]-(connected:base)
|
1704 |
+
RETURN node_id, connected.entity_id AS connected_id
|
1705 |
+
$$) AS (node_id text, connected_id text)""" % (
|
1706 |
+
self.graph_name,
|
1707 |
+
formatted_ids,
|
1708 |
+
)
|
1709 |
+
|
1710 |
+
outgoing_results = await self._query(outgoing_query)
|
1711 |
+
incoming_results = await self._query(incoming_query)
|
1712 |
|
|
|
1713 |
nodes_edges_dict = {node_id: [] for node_id in node_ids}
|
1714 |
+
|
1715 |
+
for result in outgoing_results:
|
1716 |
if result["node_id"] and result["connected_id"]:
|
1717 |
nodes_edges_dict[result["node_id"]].append(
|
1718 |
(result["node_id"], result["connected_id"])
|
1719 |
)
|
1720 |
+
|
1721 |
+
for result in incoming_results:
|
1722 |
+
if result["node_id"] and result["connected_id"]:
|
1723 |
+
nodes_edges_dict[result["node_id"]].append(
|
1724 |
+
(result["connected_id"], result["node_id"])
|
1725 |
+
)
|
1726 |
|
1727 |
return nodes_edges_dict
|
1728 |
|