yangdx commited on
Commit
0ec71c5
·
1 Parent(s): 4238bff

Changed node label from 'Entity' to 'base' and fix edge deletion error in PostgreSQL AGE graph

Browse files
Files changed (1) hide show
  1. lightrag/kg/postgres_impl.py +33 -36
lightrag/kg/postgres_impl.py CHANGED
@@ -1258,7 +1258,7 @@ class PGGraphStorage(BaseGraphStorage):
1258
  entity_name_label = self._encode_graph_label(node_id.strip('"'))
1259
 
1260
  query = """SELECT * FROM cypher('%s', $$
1261
- MATCH (n:Entity {node_id: "%s"})
1262
  RETURN count(n) > 0 AS node_exists
1263
  $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
1264
 
@@ -1271,7 +1271,7 @@ class PGGraphStorage(BaseGraphStorage):
1271
  tgt_label = self._encode_graph_label(target_node_id.strip('"'))
1272
 
1273
  query = """SELECT * FROM cypher('%s', $$
1274
- MATCH (a:Entity {node_id: "%s"})-[r]-(b:Entity {node_id: "%s"})
1275
  RETURN COUNT(r) > 0 AS edge_exists
1276
  $$) AS (edge_exists bool)""" % (
1277
  self.graph_name,
@@ -1286,7 +1286,7 @@ class PGGraphStorage(BaseGraphStorage):
1286
  async def get_node(self, node_id: str) -> dict[str, str] | None:
1287
  label = self._encode_graph_label(node_id.strip('"'))
1288
  query = """SELECT * FROM cypher('%s', $$
1289
- MATCH (n:Entity {node_id: "%s"})
1290
  RETURN n
1291
  $$) AS (n agtype)""" % (self.graph_name, label)
1292
  record = await self._query(query)
@@ -1301,7 +1301,7 @@ class PGGraphStorage(BaseGraphStorage):
1301
  label = self._encode_graph_label(node_id.strip('"'))
1302
 
1303
  query = """SELECT * FROM cypher('%s', $$
1304
- MATCH (n:Entity {node_id: "%s"})-[]->(x)
1305
  RETURN count(x) AS total_edge_count
1306
  $$) AS (total_edge_count integer)""" % (self.graph_name, label)
1307
  record = (await self._query(query))[0]
@@ -1329,7 +1329,7 @@ class PGGraphStorage(BaseGraphStorage):
1329
  tgt_label = self._encode_graph_label(target_node_id.strip('"'))
1330
 
1331
  query = """SELECT * FROM cypher('%s', $$
1332
- MATCH (a:Entity {node_id: "%s"})-[r]->(b:Entity {node_id: "%s"})
1333
  RETURN properties(r) as edge_properties
1334
  LIMIT 1
1335
  $$) AS (edge_properties agtype)""" % (
@@ -1351,8 +1351,8 @@ class PGGraphStorage(BaseGraphStorage):
1351
  label = self._encode_graph_label(source_node_id.strip('"'))
1352
 
1353
  query = """SELECT * FROM cypher('%s', $$
1354
- MATCH (n:Entity {node_id: "%s"})
1355
- OPTIONAL MATCH (n)-[]-(connected)
1356
  RETURN n, connected
1357
  $$) AS (n agtype, connected agtype)""" % (
1358
  self.graph_name,
@@ -1396,7 +1396,7 @@ class PGGraphStorage(BaseGraphStorage):
1396
  properties = node_data
1397
 
1398
  query = """SELECT * FROM cypher('%s', $$
1399
- MERGE (n:Entity {node_id: "%s"})
1400
  SET n += %s
1401
  RETURN n
1402
  $$) AS (n agtype)""" % (
@@ -1433,9 +1433,9 @@ class PGGraphStorage(BaseGraphStorage):
1433
  edge_properties = edge_data
1434
 
1435
  query = """SELECT * FROM cypher('%s', $$
1436
- MATCH (source:Entity {node_id: "%s"})
1437
  WITH source
1438
- MATCH (target:Entity {node_id: "%s"})
1439
  MERGE (source)-[r:DIRECTED]->(target)
1440
  SET r += %s
1441
  RETURN r
@@ -1466,7 +1466,7 @@ class PGGraphStorage(BaseGraphStorage):
1466
  label = self._encode_graph_label(node_id.strip('"'))
1467
 
1468
  query = """SELECT * FROM cypher('%s', $$
1469
- MATCH (n:Entity {node_id: "%s"})
1470
  DETACH DELETE n
1471
  $$) AS (n agtype)""" % (self.graph_name, label)
1472
 
@@ -1489,8 +1489,8 @@ class PGGraphStorage(BaseGraphStorage):
1489
  node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids])
1490
 
1491
  query = """SELECT * FROM cypher('%s', $$
1492
- MATCH (n:Entity)
1493
- WHERE n.node_id IN [%s]
1494
  DETACH DELETE n
1495
  $$) AS (n agtype)""" % (self.graph_name, node_id_list)
1496
 
@@ -1507,26 +1507,21 @@ class PGGraphStorage(BaseGraphStorage):
1507
  Args:
1508
  edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
1509
  """
1510
- encoded_edges = [
1511
- (
1512
- self._encode_graph_label(src.strip('"')),
1513
- self._encode_graph_label(tgt.strip('"')),
1514
- )
1515
- for src, tgt in edges
1516
- ]
1517
- edge_list = ", ".join([f'["{src}", "{tgt}"]' for src, tgt in encoded_edges])
1518
 
1519
- query = """SELECT * FROM cypher('%s', $$
1520
- MATCH (a:Entity)-[r]->(b:Entity)
1521
- WHERE [a.node_id, b.node_id] IN [%s]
1522
- DELETE r
1523
- $$) AS (r agtype)""" % (self.graph_name, edge_list)
1524
 
1525
- try:
1526
- await self._query(query, readonly=False)
1527
- except Exception as e:
1528
- logger.error("Error during edge removal: {%s}", e)
1529
- raise
 
1530
 
1531
  async def get_all_labels(self) -> list[str]:
1532
  """
@@ -1537,8 +1532,10 @@ class PGGraphStorage(BaseGraphStorage):
1537
  """
1538
  query = (
1539
  """SELECT * FROM cypher('%s', $$
1540
- MATCH (n:Entity)
1541
- RETURN DISTINCT n.node_id AS label
 
 
1542
  $$) AS (label text)"""
1543
  % self.graph_name
1544
  )
@@ -1584,15 +1581,15 @@ class PGGraphStorage(BaseGraphStorage):
1584
  # Build the query based on whether we want the full graph or a specific subgraph.
1585
  if node_label == "*":
1586
  query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1587
- MATCH (n:Entity)
1588
- OPTIONAL MATCH (n)-[r]->(m:Entity)
1589
  RETURN n, r, m
1590
  LIMIT {MAX_GRAPH_NODES}
1591
  $$) AS (n agtype, r agtype, m agtype)"""
1592
  else:
1593
  encoded_label = self._encode_graph_label(node_label.strip('"'))
1594
  query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1595
- MATCH (n:Entity {{node_id: "{encoded_label}"}})
1596
  OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
1597
  RETURN nodes(p) AS nodes, relationships(p) AS relationships
1598
  LIMIT {MAX_GRAPH_NODES}
 
1258
  entity_name_label = self._encode_graph_label(node_id.strip('"'))
1259
 
1260
  query = """SELECT * FROM cypher('%s', $$
1261
+ MATCH (n:base {node_id: "%s"})
1262
  RETURN count(n) > 0 AS node_exists
1263
  $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
1264
 
 
1271
  tgt_label = self._encode_graph_label(target_node_id.strip('"'))
1272
 
1273
  query = """SELECT * FROM cypher('%s', $$
1274
+ MATCH (a:base {node_id: "%s"})-[r]-(b:base {node_id: "%s"})
1275
  RETURN COUNT(r) > 0 AS edge_exists
1276
  $$) AS (edge_exists bool)""" % (
1277
  self.graph_name,
 
1286
  async def get_node(self, node_id: str) -> dict[str, str] | None:
1287
  label = self._encode_graph_label(node_id.strip('"'))
1288
  query = """SELECT * FROM cypher('%s', $$
1289
+ MATCH (n:base {node_id: "%s"})
1290
  RETURN n
1291
  $$) AS (n agtype)""" % (self.graph_name, label)
1292
  record = await self._query(query)
 
1301
  label = self._encode_graph_label(node_id.strip('"'))
1302
 
1303
  query = """SELECT * FROM cypher('%s', $$
1304
+ MATCH (n:base {node_id: "%s"})-[]->(x)
1305
  RETURN count(x) AS total_edge_count
1306
  $$) AS (total_edge_count integer)""" % (self.graph_name, label)
1307
  record = (await self._query(query))[0]
 
1329
  tgt_label = self._encode_graph_label(target_node_id.strip('"'))
1330
 
1331
  query = """SELECT * FROM cypher('%s', $$
1332
+ MATCH (a:base {node_id: "%s"})-[r]->(b:base {node_id: "%s"})
1333
  RETURN properties(r) as edge_properties
1334
  LIMIT 1
1335
  $$) AS (edge_properties agtype)""" % (
 
1351
  label = self._encode_graph_label(source_node_id.strip('"'))
1352
 
1353
  query = """SELECT * FROM cypher('%s', $$
1354
+ MATCH (n:base {node_id: "%s"})
1355
+ OPTIONAL MATCH (n)-[]-(connected:base)
1356
  RETURN n, connected
1357
  $$) AS (n agtype, connected agtype)""" % (
1358
  self.graph_name,
 
1396
  properties = node_data
1397
 
1398
  query = """SELECT * FROM cypher('%s', $$
1399
+ MERGE (n:base {node_id: "%s"})
1400
  SET n += %s
1401
  RETURN n
1402
  $$) AS (n agtype)""" % (
 
1433
  edge_properties = edge_data
1434
 
1435
  query = """SELECT * FROM cypher('%s', $$
1436
+ MATCH (source:base {node_id: "%s"})
1437
  WITH source
1438
+ MATCH (target:base {node_id: "%s"})
1439
  MERGE (source)-[r:DIRECTED]->(target)
1440
  SET r += %s
1441
  RETURN r
 
1466
  label = self._encode_graph_label(node_id.strip('"'))
1467
 
1468
  query = """SELECT * FROM cypher('%s', $$
1469
+ MATCH (n:base {entity_id: "%s"})
1470
  DETACH DELETE n
1471
  $$) AS (n agtype)""" % (self.graph_name, label)
1472
 
 
1489
  node_id_list = ", ".join([f'"{node_id}"' for node_id in encoded_node_ids])
1490
 
1491
  query = """SELECT * FROM cypher('%s', $$
1492
+ MATCH (n:base)
1493
+ WHERE n.nentity_id IN [%s]
1494
  DETACH DELETE n
1495
  $$) AS (n agtype)""" % (self.graph_name, node_id_list)
1496
 
 
1507
  Args:
1508
  edges (list[tuple[str, str]]): A list of edges to remove, where each edge is a tuple of (source_node_id, target_node_id).
1509
  """
1510
+ for source, target in edges:
1511
+ src_label = self._encode_graph_label(source.strip('"'))
1512
+ tgt_label = self._encode_graph_label(target.strip('"'))
 
 
 
 
 
1513
 
1514
+ query = """SELECT * FROM cypher('%s', $$
1515
+ MATCH (a:base {node_id: "%s"})-[r]->(b:base {node_id: "%s"})
1516
+ DELETE r
1517
+ $$) AS (r agtype)""" % (self.graph_name, src_label, tgt_label)
 
1518
 
1519
+ try:
1520
+ await self._query(query, readonly=False)
1521
+ logger.debug(f"Deleted edge from '{source}' to '{target}'")
1522
+ except Exception as e:
1523
+ logger.error(f"Error during edge deletion: {str(e)}")
1524
+ raise
1525
 
1526
  async def get_all_labels(self) -> list[str]:
1527
  """
 
1532
  """
1533
  query = (
1534
  """SELECT * FROM cypher('%s', $$
1535
+ MATCH (n:base)
1536
+ WHERE n.entity_id IS NOT NULL
1537
+ RETURN DISTINCT n.entity_id AS label
1538
+ ORDER BY label
1539
  $$) AS (label text)"""
1540
  % self.graph_name
1541
  )
 
1581
  # Build the query based on whether we want the full graph or a specific subgraph.
1582
  if node_label == "*":
1583
  query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1584
+ MATCH (n:base)
1585
+ OPTIONAL MATCH (n)-[r]->(m:base)
1586
  RETURN n, r, m
1587
  LIMIT {MAX_GRAPH_NODES}
1588
  $$) AS (n agtype, r agtype, m agtype)"""
1589
  else:
1590
  encoded_label = self._encode_graph_label(node_label.strip('"'))
1591
  query = f"""SELECT * FROM cypher('{self.graph_name}', $$
1592
+ MATCH (n:base {{entity_id: "{encoded_label}"}})
1593
  OPTIONAL MATCH p = (n)-[*..{max_depth}]-(m)
1594
  RETURN nodes(p) AS nodes, relationships(p) AS relationships
1595
  LIMIT {MAX_GRAPH_NODES}