zrguo commited on
Commit
0039818
·
unverified ·
2 Parent(s): 2430677 bd8c26c

Merge pull request #576 from ShanGor/main

Browse files

Revise the AGE implementation on get_node_edges, to align with Neo4j behavior.

lightrag/kg/postgres_impl.py CHANGED
@@ -141,13 +141,16 @@ class PostgreSQLDB:
141
  await connection.execute(sql)
142
  else:
143
  await connection.execute(sql, *data.values())
144
- except asyncpg.exceptions.UniqueViolationError as e:
 
 
 
145
  if upsert:
146
  print("Key value duplicate, but upsert succeeded.")
147
  else:
148
  logger.error(f"Upsert error: {e}")
149
  except Exception as e:
150
- logger.error(f"PostgreSQL database error: {e}")
151
  print(sql)
152
  print(data)
153
  raise
@@ -885,7 +888,12 @@ class PGGraphStorage(BaseGraphStorage):
885
  )
886
 
887
  if source_label and target_label:
888
- edges.append((source_label, target_label))
 
 
 
 
 
889
 
890
  return edges
891
 
 
141
  await connection.execute(sql)
142
  else:
143
  await connection.execute(sql, *data.values())
144
+ except (
145
+ asyncpg.exceptions.UniqueViolationError,
146
+ asyncpg.exceptions.DuplicateTableError,
147
+ ) as e:
148
  if upsert:
149
  print("Key value duplicate, but upsert succeeded.")
150
  else:
151
  logger.error(f"Upsert error: {e}")
152
  except Exception as e:
153
+ logger.error(f"PostgreSQL database error: {e.__class__} - {e}")
154
  print(sql)
155
  print(data)
156
  raise
 
888
  )
889
 
890
  if source_label and target_label:
891
+ edges.append(
892
+ (
893
+ PGGraphStorage._decode_graph_label(source_label),
894
+ PGGraphStorage._decode_graph_label(target_label),
895
+ )
896
+ )
897
 
898
  return edges
899
 
lightrag/kg/postgres_impl_test.py CHANGED
@@ -61,7 +61,7 @@ db = PostgreSQLDB(
61
  "port": 15432,
62
  "user": "rag",
63
  "password": "rag",
64
- "database": "rag",
65
  }
66
  )
67
 
@@ -74,8 +74,12 @@ async def query_with_age():
74
  embedding_func=None,
75
  )
76
  graph.db = db
77
- res = await graph.get_node('"CHRISTMAS-TIME"')
78
  print("Node is: ", res)
 
 
 
 
79
 
80
 
81
  async def create_edge_with_age():
 
61
  "port": 15432,
62
  "user": "rag",
63
  "password": "rag",
64
+ "database": "r1",
65
  }
66
  )
67
 
 
74
  embedding_func=None,
75
  )
76
  graph.db = db
77
+ res = await graph.get_node('"A CHRISTMAS CAROL"')
78
  print("Node is: ", res)
79
+ res = await graph.get_edge('"A CHRISTMAS CAROL"', "PROJECT GUTENBERG")
80
+ print("Edge is: ", res)
81
+ res = await graph.get_node_edges('"SCROOGE"')
82
+ print("Node Edges are: ", res)
83
 
84
 
85
  async def create_edge_with_age():