yangdx
commited on
Commit
·
ec8fba9
1
Parent(s):
073182d
Implement batch query funtions for PGGraphStorage of PostgreSQl AGE graph storage
Browse files- lightrag/kg/postgres_impl.py +191 -0
lightrag/kg/postgres_impl.py
CHANGED
@@ -1458,6 +1458,197 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1458 |
logger.error(f"Error during edge deletion: {str(e)}")
|
1459 |
raise
|
1460 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1461 |
async def get_all_labels(self) -> list[str]:
|
1462 |
"""
|
1463 |
Get all labels (node IDs) in the graph.
|
|
|
1458 |
logger.error(f"Error during edge deletion: {str(e)}")
|
1459 |
raise
|
1460 |
|
1461 |
+
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
1462 |
+
"""
|
1463 |
+
Retrieve multiple nodes in one query using UNWIND.
|
1464 |
+
|
1465 |
+
Args:
|
1466 |
+
node_ids: List of node entity IDs to fetch.
|
1467 |
+
|
1468 |
+
Returns:
|
1469 |
+
A dictionary mapping each node_id to its node data (or None if not found).
|
1470 |
+
"""
|
1471 |
+
if not node_ids:
|
1472 |
+
return {}
|
1473 |
+
|
1474 |
+
# Format node IDs for the query
|
1475 |
+
formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids])
|
1476 |
+
|
1477 |
+
query = """SELECT * FROM cypher('%s', $$
|
1478 |
+
UNWIND [%s] AS node_id
|
1479 |
+
MATCH (n:base {entity_id: node_id})
|
1480 |
+
RETURN node_id, n
|
1481 |
+
$$) AS (node_id text, n agtype)""" % (
|
1482 |
+
self.graph_name,
|
1483 |
+
formatted_ids
|
1484 |
+
)
|
1485 |
+
|
1486 |
+
results = await self._query(query)
|
1487 |
+
|
1488 |
+
# Build result dictionary
|
1489 |
+
nodes_dict = {}
|
1490 |
+
for result in results:
|
1491 |
+
if result["node_id"] and result["n"]:
|
1492 |
+
node_dict = result["n"]["properties"]
|
1493 |
+
# Remove the 'base' label if present in a 'labels' property
|
1494 |
+
if "labels" in node_dict:
|
1495 |
+
node_dict["labels"] = [label for label in node_dict["labels"] if label != "base"]
|
1496 |
+
nodes_dict[result["node_id"]] = node_dict
|
1497 |
+
|
1498 |
+
return nodes_dict
|
1499 |
+
|
1500 |
+
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
1501 |
+
"""
|
1502 |
+
Retrieve the degree for multiple nodes in a single query using UNWIND.
|
1503 |
+
|
1504 |
+
Args:
|
1505 |
+
node_ids: List of node labels (entity_id values) to look up.
|
1506 |
+
|
1507 |
+
Returns:
|
1508 |
+
A dictionary mapping each node_id to its degree (number of relationships).
|
1509 |
+
If a node is not found, its degree will be set to 0.
|
1510 |
+
"""
|
1511 |
+
if not node_ids:
|
1512 |
+
return {}
|
1513 |
+
|
1514 |
+
# Format node IDs for the query
|
1515 |
+
formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids])
|
1516 |
+
|
1517 |
+
query = """SELECT * FROM cypher('%s', $$
|
1518 |
+
UNWIND [%s] AS node_id
|
1519 |
+
MATCH (n:base {entity_id: node_id})
|
1520 |
+
OPTIONAL MATCH (n)-[r]-()
|
1521 |
+
RETURN node_id, count(r) AS degree
|
1522 |
+
$$) AS (node_id text, degree bigint)""" % (
|
1523 |
+
self.graph_name,
|
1524 |
+
formatted_ids
|
1525 |
+
)
|
1526 |
+
|
1527 |
+
results = await self._query(query)
|
1528 |
+
|
1529 |
+
# Build result dictionary
|
1530 |
+
degrees_dict = {}
|
1531 |
+
for result in results:
|
1532 |
+
if result["node_id"] is not None:
|
1533 |
+
degrees_dict[result["node_id"]] = int(result["degree"])
|
1534 |
+
|
1535 |
+
# Ensure all requested node_ids are in the result dictionary
|
1536 |
+
for node_id in node_ids:
|
1537 |
+
if node_id not in degrees_dict:
|
1538 |
+
degrees_dict[node_id] = 0
|
1539 |
+
|
1540 |
+
return degrees_dict
|
1541 |
+
|
1542 |
+
async def edge_degrees_batch(self, edges: list[tuple[str, str]]) -> dict[tuple[str, str], int]:
|
1543 |
+
"""
|
1544 |
+
Calculate the combined degree for each edge (sum of the source and target node degrees)
|
1545 |
+
in batch using the already implemented node_degrees_batch.
|
1546 |
+
|
1547 |
+
Args:
|
1548 |
+
edges: List of (source_node_id, target_node_id) tuples
|
1549 |
+
|
1550 |
+
Returns:
|
1551 |
+
Dictionary mapping edge tuples to their combined degrees
|
1552 |
+
"""
|
1553 |
+
if not edges:
|
1554 |
+
return {}
|
1555 |
+
|
1556 |
+
# Use node_degrees_batch to get all node degrees efficiently
|
1557 |
+
all_nodes = set()
|
1558 |
+
for src, tgt in edges:
|
1559 |
+
all_nodes.add(src)
|
1560 |
+
all_nodes.add(tgt)
|
1561 |
+
|
1562 |
+
node_degrees = await self.node_degrees_batch(list(all_nodes))
|
1563 |
+
|
1564 |
+
# Calculate edge degrees
|
1565 |
+
edge_degrees_dict = {}
|
1566 |
+
for src, tgt in edges:
|
1567 |
+
src_degree = node_degrees.get(src, 0)
|
1568 |
+
tgt_degree = node_degrees.get(tgt, 0)
|
1569 |
+
edge_degrees_dict[(src, tgt)] = src_degree + tgt_degree
|
1570 |
+
|
1571 |
+
return edge_degrees_dict
|
1572 |
+
|
1573 |
+
async def get_edges_batch(self, pairs: list[dict[str, str]]) -> dict[tuple[str, str], dict]:
|
1574 |
+
"""
|
1575 |
+
Retrieve edge properties for multiple (src, tgt) pairs in one query.
|
1576 |
+
|
1577 |
+
Args:
|
1578 |
+
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
|
1579 |
+
|
1580 |
+
Returns:
|
1581 |
+
A dictionary mapping (src, tgt) tuples to their edge properties.
|
1582 |
+
"""
|
1583 |
+
if not pairs:
|
1584 |
+
return {}
|
1585 |
+
|
1586 |
+
# 从字典列表中提取源节点和目标节点ID
|
1587 |
+
src_nodes = []
|
1588 |
+
tgt_nodes = []
|
1589 |
+
for pair in pairs:
|
1590 |
+
src_nodes.append(pair["src"].replace('"', ''))
|
1591 |
+
tgt_nodes.append(pair["tgt"].replace('"', ''))
|
1592 |
+
|
1593 |
+
# 构建查询,使用数组索引来匹配源节点和目标节点
|
1594 |
+
src_array = ", ".join([f'"{src}"' for src in src_nodes])
|
1595 |
+
tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])
|
1596 |
+
|
1597 |
+
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
1598 |
+
WITH [{src_array}] AS sources, [{tgt_array}] AS targets
|
1599 |
+
UNWIND range(0, size(sources)-1) AS i
|
1600 |
+
MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]-(b:base {{entity_id: targets[i]}})
|
1601 |
+
RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
|
1602 |
+
$$) AS (source text, target text, edge_properties agtype)"""
|
1603 |
+
|
1604 |
+
results = await self._query(query)
|
1605 |
+
|
1606 |
+
# 构建结果字典
|
1607 |
+
edges_dict = {}
|
1608 |
+
for result in results:
|
1609 |
+
if result["source"] and result["target"] and result["edge_properties"]:
|
1610 |
+
edges_dict[(result["source"], result["target"])] = result["edge_properties"]
|
1611 |
+
|
1612 |
+
return edges_dict
|
1613 |
+
|
1614 |
+
async def get_nodes_edges_batch(self, node_ids: list[str]) -> dict[str, list[tuple[str, str]]]:
|
1615 |
+
"""
|
1616 |
+
Get all edges for multiple nodes in a single batch operation.
|
1617 |
+
|
1618 |
+
Args:
|
1619 |
+
node_ids: List of node IDs to get edges for
|
1620 |
+
|
1621 |
+
Returns:
|
1622 |
+
Dictionary mapping node IDs to lists of (source, target) edge tuples
|
1623 |
+
"""
|
1624 |
+
if not node_ids:
|
1625 |
+
return {}
|
1626 |
+
|
1627 |
+
# Format node IDs for the query
|
1628 |
+
formatted_ids = ", ".join(['"' + node_id.replace('"', '') + '"' for node_id in node_ids])
|
1629 |
+
|
1630 |
+
query = """SELECT * FROM cypher('%s', $$
|
1631 |
+
UNWIND [%s] AS node_id
|
1632 |
+
MATCH (n:base {entity_id: node_id})
|
1633 |
+
OPTIONAL MATCH (n)-[]-(connected:base)
|
1634 |
+
RETURN node_id, connected.entity_id AS connected_id
|
1635 |
+
$$) AS (node_id text, connected_id text)""" % (
|
1636 |
+
self.graph_name,
|
1637 |
+
formatted_ids
|
1638 |
+
)
|
1639 |
+
|
1640 |
+
results = await self._query(query)
|
1641 |
+
|
1642 |
+
# Build result dictionary
|
1643 |
+
nodes_edges_dict = {node_id: [] for node_id in node_ids}
|
1644 |
+
for result in results:
|
1645 |
+
if result["node_id"] and result["connected_id"]:
|
1646 |
+
nodes_edges_dict[result["node_id"]].append(
|
1647 |
+
(result["node_id"], result["connected_id"])
|
1648 |
+
)
|
1649 |
+
|
1650 |
+
return nodes_edges_dict
|
1651 |
+
|
1652 |
async def get_all_labels(self) -> list[str]:
|
1653 |
"""
|
1654 |
Get all labels (node IDs) in the graph.
|