yangdx
commited on
Commit
·
0ec71c5
1
Parent(s):
4238bff
Changed node label from 'Entity' to 'base' and fix edge deletion error in PostgreSQL AGE graph
Browse files- lightrag/kg/postgres_impl.py +33 -36
lightrag/kg/postgres_impl.py
CHANGED
@@ -1258,7 +1258,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1258 |
entity_name_label = self._encode_graph_label(node_id.strip('"'))
|
1259 |
|
1260 |
query = """SELECT * FROM cypher('%s', $$
|
1261 |
-
MATCH (n:
|
1262 |
RETURN count(n) > 0 AS node_exists
|
1263 |
$$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
|
1264 |
|
@@ -1271,7 +1271,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1271 |
tgt_label = self._encode_graph_label(target_node_id.strip('"'))
|
1272 |
|
1273 |
query = """SELECT * FROM cypher('%s', $$
|
1274 |
-
MATCH (a:
|
1275 |
RETURN COUNT(r) > 0 AS edge_exists
|
1276 |
$$) AS (edge_exists bool)""" % (
|
1277 |
self.graph_name,
|
@@ -1286,7 +1286,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1286 |
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
1287 |
label = self._encode_graph_label(node_id.strip('"'))
|
1288 |
query = """SELECT * FROM cypher('%s', $$
|
1289 |
-
MATCH (n:
|
1290 |
RETURN n
|
1291 |
$$) AS (n agtype)""" % (self.graph_name, label)
|
1292 |
record = await self._query(query)
|
@@ -1301,7 +1301,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1301 |
label = self._encode_graph_label(node_id.strip('"'))
|
1302 |
|
1303 |
query = """SELECT * FROM cypher('%s', $$
|
1304 |
-
MATCH (n:
|
1305 |
RETURN count(x) AS total_edge_count
|
1306 |
$$) AS (total_edge_count integer)""" % (self.graph_name, label)
|
1307 |
record = (await self._query(query))[0]
|
@@ -1329,7 +1329,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1329 |
tgt_label = self._encode_graph_label(target_node_id.strip('"'))
|
1330 |
|
1331 |
query = """SELECT * FROM cypher('%s', $$
|
1332 |
-
MATCH (a:
|
1333 |
RETURN properties(r) as edge_properties
|
1334 |
LIMIT 1
|
1335 |
$$) AS (edge_properties agtype)""" % (
|
@@ -1351,8 +1351,8 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1351 |
label = self._encode_graph_label(source_node_id.strip('"'))
|
1352 |
|
1353 |
query = """SELECT * FROM cypher('%s', $$
|
1354 |
-
MATCH (n:
|
1355 |
-
OPTIONAL MATCH (n)-[]-(connected)
|
1356 |
RETURN n, connected
|
1357 |
$$) AS (n agtype, connected agtype)""" % (
|
1358 |
self.graph_name,
|
@@ -1396,7 +1396,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1396 |
properties = node_data
|
1397 |
|
1398 |
query = """SELECT * FROM cypher('%s', $$
|
1399 |
-
MERGE (n:
|
1400 |
SET n += %s
|
1401 |
RETURN n
|
1402 |
$$) AS (n agtype)""" % (
|
@@ -1433,9 +1433,9 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1433 |
edge_properties = edge_data
|
1434 |
|
1435 |
query = """SELECT * FROM cypher('%s', $$
|
1436 |
-
MATCH (source:
|
1437 |
WITH source
|
1438 |
-
MATCH (target:
|
1439 |
MERGE (source)-[r:DIRECTED]->(target)
|
1440 |
SET r += %s
|
1441 |
RETURN r
|
@@ -1466,7 +1466,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1466 |
label = self._encode_graph_label(node_id.strip('"'))
|
1467 |
|
1468 |
query = """SELECT * FROM cypher('%s', $$
|
1469 |
-
MATCH (n:
|
1470 |
DETACH DELETE n
|
1471 |
$$) AS (n agtype)""" % (self.graph_name, label)
|
1472 |
|
@@ -1489,8 +1489,8 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1489 |
node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids])
|
1490 |
|
1491 |
query = """SELECT * FROM cypher('%s', $$
|
1492 |
-
MATCH (n:
|
1493 |
-
WHERE n.
|
1494 |
DETACH DELETE n
|
1495 |
$$) AS (n agtype)""" % (self.graph_name, node_id_list)
|
1496 |
|
@@ -1507,26 +1507,21 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1507 |
Args:
|
1508 |
edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
|
1509 |
"""
|
1510 |
-
|
1511 |
-
(
|
1512 |
-
|
1513 |
-
self._encode_graph_label(tgt.strip('"')),
|
1514 |
-
)
|
1515 |
-
for src, tgt in edges
|
1516 |
-
]
|
1517 |
-
edge_list = ", ".join([f'["{src}", "{tgt}"]' for src, tgt in encoded_edges])
|
1518 |
|
1519 |
-
|
1520 |
-
|
1521 |
-
|
1522 |
-
|
1523 |
-
$$) AS (r agtype)""" % (self.graph_name, edge_list)
|
1524 |
|
1525 |
-
|
1526 |
-
|
1527 |
-
|
1528 |
-
|
1529 |
-
|
|
|
1530 |
|
1531 |
async def get_all_labels(self) -> list[str]:
|
1532 |
"""
|
@@ -1537,8 +1532,10 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1537 |
"""
|
1538 |
query = (
|
1539 |
"""SELECT * FROM cypher('%s', $$
|
1540 |
-
MATCH (n:
|
1541 |
-
|
|
|
|
|
1542 |
$$) AS (label text)"""
|
1543 |
% self.graph_name
|
1544 |
)
|
@@ -1584,15 +1581,15 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1584 |
# Build the query based on whether we want the full graph or a specific subgraph.
|
1585 |
if node_label == "*":
|
1586 |
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
1587 |
-
MATCH (n:
|
1588 |
-
OPTIONAL MATCH (n)-[r]->(m:
|
1589 |
RETURN n, r, m
|
1590 |
LIMIT {MAX_GRAPH_NODES}
|
1591 |
$$) AS (n agtype, r agtype, m agtype)"""
|
1592 |
else:
|
1593 |
encoded_label = self._encode_graph_label(node_label.strip('"'))
|
1594 |
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
1595 |
-
MATCH (n:
|
1596 |
OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
|
1597 |
RETURN nodes(p) AS nodes, relationships(p) AS relationships
|
1598 |
LIMIT {MAX_GRAPH_NODES}
|
|
|
1258 |
entity_name_label = self._encode_graph_label(node_id.strip('"'))
|
1259 |
|
1260 |
query = """SELECT * FROM cypher('%s', $$
|
1261 |
+
MATCH (n:base {node_id: "%s"})
|
1262 |
RETURN count(n) > 0 AS node_exists
|
1263 |
$$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
|
1264 |
|
|
|
1271 |
tgt_label = self._encode_graph_label(target_node_id.strip('"'))
|
1272 |
|
1273 |
query = """SELECT * FROM cypher('%s', $$
|
1274 |
+
MATCH (a:base {node_id: "%s"})-[r]-(b:base {node_id: "%s"})
|
1275 |
RETURN COUNT(r) > 0 AS edge_exists
|
1276 |
$$) AS (edge_exists bool)""" % (
|
1277 |
self.graph_name,
|
|
|
1286 |
async def get_node(self, node_id: str) -> dict[str, str] | None:
|
1287 |
label = self._encode_graph_label(node_id.strip('"'))
|
1288 |
query = """SELECT * FROM cypher('%s', $$
|
1289 |
+
MATCH (n:base {node_id: "%s"})
|
1290 |
RETURN n
|
1291 |
$$) AS (n agtype)""" % (self.graph_name, label)
|
1292 |
record = await self._query(query)
|
|
|
1301 |
label = self._encode_graph_label(node_id.strip('"'))
|
1302 |
|
1303 |
query = """SELECT * FROM cypher('%s', $$
|
1304 |
+
MATCH (n:base {node_id: "%s"})-[]->(x)
|
1305 |
RETURN count(x) AS total_edge_count
|
1306 |
$$) AS (total_edge_count integer)""" % (self.graph_name, label)
|
1307 |
record = (await self._query(query))[0]
|
|
|
1329 |
tgt_label = self._encode_graph_label(target_node_id.strip('"'))
|
1330 |
|
1331 |
query = """SELECT * FROM cypher('%s', $$
|
1332 |
+
MATCH (a:base {node_id: "%s"})-[r]->(b:base {node_id: "%s"})
|
1333 |
RETURN properties(r) as edge_properties
|
1334 |
LIMIT 1
|
1335 |
$$) AS (edge_properties agtype)""" % (
|
|
|
1351 |
label = self._encode_graph_label(source_node_id.strip('"'))
|
1352 |
|
1353 |
query = """SELECT * FROM cypher('%s', $$
|
1354 |
+
MATCH (n:base {node_id: "%s"})
|
1355 |
+
OPTIONAL MATCH (n)-[]-(connected:base)
|
1356 |
RETURN n, connected
|
1357 |
$$) AS (n agtype, connected agtype)""" % (
|
1358 |
self.graph_name,
|
|
|
1396 |
properties = node_data
|
1397 |
|
1398 |
query = """SELECT * FROM cypher('%s', $$
|
1399 |
+
MERGE (n:base {node_id: "%s"})
|
1400 |
SET n += %s
|
1401 |
RETURN n
|
1402 |
$$) AS (n agtype)""" % (
|
|
|
1433 |
edge_properties = edge_data
|
1434 |
|
1435 |
query = """SELECT * FROM cypher('%s', $$
|
1436 |
+
MATCH (source:base {node_id: "%s"})
|
1437 |
WITH source
|
1438 |
+
MATCH (target:base {node_id: "%s"})
|
1439 |
MERGE (source)-[r:DIRECTED]->(target)
|
1440 |
SET r += %s
|
1441 |
RETURN r
|
|
|
1466 |
label = self._encode_graph_label(node_id.strip('"'))
|
1467 |
|
1468 |
query = """SELECT * FROM cypher('%s', $$
|
1469 |
+
MATCH (n:base {entity_id: "%s"})
|
1470 |
DETACH DELETE n
|
1471 |
$$) AS (n agtype)""" % (self.graph_name, label)
|
1472 |
|
|
|
1489 |
node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids])
|
1490 |
|
1491 |
query = """SELECT * FROM cypher('%s', $$
|
1492 |
+
MATCH (n:base)
|
1493 |
+
WHERE n.nentity_id IN [%s]
|
1494 |
DETACH DELETE n
|
1495 |
$$) AS (n agtype)""" % (self.graph_name, node_id_list)
|
1496 |
|
|
|
1507 |
Args:
|
1508 |
edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
|
1509 |
"""
|
1510 |
+
for source, target in edges:
|
1511 |
+
src_label = self._encode_graph_label(source.strip('"'))
|
1512 |
+
tgt_label = self._encode_graph_label(target.strip('"'))
|
|
|
|
|
|
|
|
|
|
|
1513 |
|
1514 |
+
query = """SELECT * FROM cypher('%s', $$
|
1515 |
+
MATCH (a:base {node_id: "%s"})-[r]->(b:base {node_id: "%s"})
|
1516 |
+
DELETE r
|
1517 |
+
$$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label)
|
|
|
1518 |
|
1519 |
+
try:
|
1520 |
+
await self._query(query, readonly=False)
|
1521 |
+
logger.debug(f"Deleted edge from '{source}' to '{target}'")
|
1522 |
+
except Exception as e:
|
1523 |
+
logger.error(f"Error during edge deletion: {str(e)}")
|
1524 |
+
raise
|
1525 |
|
1526 |
async def get_all_labels(self) -> list[str]:
|
1527 |
"""
|
|
|
1532 |
"""
|
1533 |
query = (
|
1534 |
"""SELECT * FROM cypher('%s', $$
|
1535 |
+
MATCH (n:base)
|
1536 |
+
WHERE n.entity_id IS NOT NULL
|
1537 |
+
RETURN DISTINCT n.entity_id AS label
|
1538 |
+
ORDER BY label
|
1539 |
$$) AS (label text)"""
|
1540 |
% self.graph_name
|
1541 |
)
|
|
|
1581 |
# Build the query based on whether we want the full graph or a specific subgraph.
|
1582 |
if node_label == "*":
|
1583 |
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
1584 |
+
MATCH (n:base)
|
1585 |
+
OPTIONAL MATCH (n)-[r]->(m:base)
|
1586 |
RETURN n, r, m
|
1587 |
LIMIT {MAX_GRAPH_NODES}
|
1588 |
$$) AS (n agtype, r agtype, m agtype)"""
|
1589 |
else:
|
1590 |
encoded_label = self._encode_graph_label(node_label.strip('"'))
|
1591 |
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
1592 |
+
MATCH (n:base {{entity_id: "{encoded_label}"}})
|
1593 |
OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
|
1594 |
RETURN nodes(p) AS nodes, relationships(p) AS relationships
|
1595 |
LIMIT {MAX_GRAPH_NODES}
|