samuel-z-chen committed on
Commit
3bda675
·
1 Parent(s): c55e3cb

Fix the Postgres get_all_labels and get_knowledge_graph methods

Browse files
examples/test_postgres.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Manual smoke test for the Postgres graph storage: it configures the
# connection through environment variables, then prints every label in the
# graph and the knowledge graph returned for a single node label.
import os
import asyncio
from lightrag.kg.postgres_impl import PGGraphStorage
from lightrag.llm.ollama import ollama_embedding
from lightrag.utils import EmbeddingFunc

#########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
# import nest_asyncio
# nest_asyncio.apply()
#########

WORKING_DIR = "./local_neo4jWorkDir"

# exist_ok=True avoids the check-then-create race of exists() + mkdir().
os.makedirs(WORKING_DIR, exist_ok=True)

# AGE
os.environ["AGE_GRAPH_NAME"] = "dickens"

# Connection settings; set here, before main() constructs the storage.
os.environ["POSTGRES_HOST"] = "localhost"
os.environ["POSTGRES_PORT"] = "15432"
os.environ["POSTGRES_USER"] = "rag"
os.environ["POSTGRES_PASSWORD"] = "rag"
os.environ["POSTGRES_DATABASE"] = "rag"


async def main():
    """Print all graph labels and the knowledge graph for one node label.

    Builds a ``PGGraphStorage`` for the "dickens" namespace, initializes it,
    prints the results of ``get_all_labels()`` and
    ``get_knowledge_graph("FEZZIWIG")``, and finalizes the storage even when
    one of the queries raises.
    """
    graph_db = PGGraphStorage(
        namespace="dickens",
        embedding_func=EmbeddingFunc(
            embedding_dim=1024,
            max_token_size=8192,
            func=lambda texts: ollama_embedding(
                texts, embed_model="bge-m3", host="http://localhost:11434"
            ),
        ),
        global_config={},
    )
    await graph_db.initialize()
    try:
        labels = await graph_db.get_all_labels()
        print("all labels", labels)

        res = await graph_db.get_knowledge_graph("FEZZIWIG")
        print("knowledge graphs", res)
    finally:
        # Always release the storage, including when a query above fails.
        await graph_db.finalize()


if __name__ == "__main__":
    asyncio.run(main())
lightrag/kg/postgres_impl.py CHANGED
@@ -810,42 +810,85 @@ class PGGraphStorage(BaseGraphStorage):
810
  v = record[k]
811
  # agtype comes back '{key: value}::type' which must be parsed
812
  if isinstance(v, str) and "::" in v:
813
- dtype = v.split("::")[-1]
814
- v = v.split("::")[0]
815
- if dtype == "vertex":
816
- vertex = json.loads(v)
817
- vertices[vertex["id"]] = vertex.get("properties")
 
 
 
 
 
 
 
 
818
 
819
  # iterate returned fields and parse appropriately
820
  for k in record.keys():
821
  v = record[k]
822
  if isinstance(v, str) and "::" in v:
823
- dtype = v.split("::")[-1]
824
- v = v.split("::")[0]
825
- else:
826
- dtype = ""
827
-
828
- if dtype == "vertex":
829
- vertex = json.loads(v)
830
- field = vertex.get("properties")
831
- if not field:
832
- field = {}
833
- field["label"] = PGGraphStorage._decode_graph_label(field["node_id"])
834
- d[k] = field
835
- # convert edge from id-label->id by replacing id with node information
836
- # we only do this if the vertex was also returned in the query
837
- # this is an attempt to be consistent with neo4j implementation
838
- elif dtype == "edge":
839
- edge = json.loads(v)
840
- d[k] = (
841
- vertices.get(edge["start_id"], {}),
842
- edge[
843
- "label"
844
- ], # we don't use decode_graph_label(), since edge label is always "DIRECTED"
845
- vertices.get(edge["end_id"], {}),
846
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
847
  else:
848
- d[k] = json.loads(v) if isinstance(v, str) else v
 
 
 
849
 
850
  return d
851
 
@@ -1319,7 +1362,7 @@ class PGGraphStorage(BaseGraphStorage):
1319
  OPTIONAL MATCH p = (n)-[*..%d]-(m)
1320
  RETURN nodes(p) AS nodes, relationships(p) AS relationships
1321
  LIMIT %d
1322
- $$) AS (nodes agtype[], relationships agtype[])""" % (
1323
  self.graph_name,
1324
  encoded_node_label,
1325
  max_depth,
@@ -1328,17 +1371,23 @@ class PGGraphStorage(BaseGraphStorage):
1328
 
1329
  results = await self._query(query)
1330
 
1331
- nodes = set()
1332
  edges = []
 
1333
 
1334
  for result in results:
1335
  if node_label == "*":
1336
  if result["n"]:
1337
  node = result["n"]
1338
- nodes.add(self._decode_graph_label(node["node_id"]))
 
 
 
1339
  if result["m"]:
1340
  node = result["m"]
1341
- nodes.add(self._decode_graph_label(node["node_id"]))
 
 
1342
  if result["r"]:
1343
  edge = result["r"]
1344
  src_id = self._decode_graph_label(edge["start_id"])
@@ -1347,16 +1396,36 @@ class PGGraphStorage(BaseGraphStorage):
1347
  else:
1348
  if result["nodes"]:
1349
  for node in result["nodes"]:
1350
- nodes.add(self._decode_graph_label(node["node_id"]))
 
 
 
1351
  if result["relationships"]:
1352
- for edge in result["relationships"]:
1353
- src_id = self._decode_graph_label(edge["start_id"])
1354
- tgt_id = self._decode_graph_label(edge["end_id"])
1355
- edges.append((src_id, tgt_id))
 
 
 
 
 
 
 
1356
 
1357
  kg = KnowledgeGraph(
1358
- nodes=[KnowledgeGraphNode(id=node_id) for node_id in nodes],
1359
- edges=[KnowledgeGraphEdge(source=src, target=tgt) for src, tgt in edges],
 
 
 
 
 
 
 
 
 
 
1360
  )
1361
 
1362
  return kg
 
810
  v = record[k]
811
  # agtype comes back '{key: value}::type' which must be parsed
812
  if isinstance(v, str) and "::" in v:
813
+ if v.startswith("[") and v.endswith("]"):
814
+ if "::vertex" not in v:
815
+ continue
816
+ v = v.replace("::vertex", "")
817
+ vertexes = json.loads(v)
818
+ for vertex in vertexes:
819
+ vertices[vertex["id"]] = vertex.get("properties")
820
+ else:
821
+ dtype = v.split("::")[-1]
822
+ v = v.split("::")[0]
823
+ if dtype == "vertex":
824
+ vertex = json.loads(v)
825
+ vertices[vertex["id"]] = vertex.get("properties")
826
 
827
  # iterate returned fields and parse appropriately
828
  for k in record.keys():
829
  v = record[k]
830
  if isinstance(v, str) and "::" in v:
831
+ if v.startswith("[") and v.endswith("]"):
832
+ if "::vertex" in v:
833
+ v = v.replace("::vertex", "")
834
+ vertexes = json.loads(v)
835
+ dl = []
836
+ for vertex in vertexes:
837
+ prop = vertex.get("properties")
838
+ if not prop:
839
+ prop = {}
840
+ prop["label"] = PGGraphStorage._decode_graph_label(
841
+ prop["node_id"]
842
+ )
843
+ dl.append(prop)
844
+ d[k] = dl
845
+
846
+ elif "::edge" in v:
847
+ v = v.replace("::edge", "")
848
+ edges = json.loads(v)
849
+ dl = []
850
+ for edge in edges:
851
+ dl.append(
852
+ (
853
+ vertices[edge["start_id"]],
854
+ edge["label"],
855
+ vertices[edge["end_id"]],
856
+ )
857
+ )
858
+ d[k] = dl
859
+ else:
860
+ print("WARNING: unsupported type")
861
+ continue
862
+
863
+ else:
864
+ dtype = v.split("::")[-1]
865
+ v = v.split("::")[0]
866
+ if dtype == "vertex":
867
+ vertex = json.loads(v)
868
+ field = vertex.get("properties")
869
+ if not field:
870
+ field = {}
871
+ field["label"] = PGGraphStorage._decode_graph_label(
872
+ field["node_id"]
873
+ )
874
+ d[k] = field
875
+ # convert edge from id-label->id by replacing id with node information
876
+ # we only do this if the vertex was also returned in the query
877
+ # this is an attempt to be consistent with neo4j implementation
878
+ elif dtype == "edge":
879
+ edge = json.loads(v)
880
+ d[k] = (
881
+ vertices.get(edge["start_id"], {}),
882
+ edge[
883
+ "label"
884
+ ], # we don't use decode_graph_label(), since edge label is always "DIRECTED"
885
+ vertices.get(edge["end_id"], {}),
886
+ )
887
  else:
888
+ if v is None or (v.count("{") < 1 and v.count("[") < 1):
889
+ d[k] = v
890
+ else:
891
+ d[k] = json.loads(v) if isinstance(v, str) else v
892
 
893
  return d
894
 
 
1362
  OPTIONAL MATCH p = (n)-[*..%d]-(m)
1363
  RETURN nodes(p) AS nodes, relationships(p) AS relationships
1364
  LIMIT %d
1365
+ $$) AS (nodes agtype, relationships agtype)""" % (
1366
  self.graph_name,
1367
  encoded_node_label,
1368
  max_depth,
 
1371
 
1372
  results = await self._query(query)
1373
 
1374
+ nodes = {}
1375
  edges = []
1376
+ unique_edge_ids = set()
1377
 
1378
  for result in results:
1379
  if node_label == "*":
1380
  if result["n"]:
1381
  node = result["n"]
1382
+ node_id = self._decode_graph_label(node["node_id"])
1383
+ if node_id not in nodes:
1384
+ nodes[node_id] = node
1385
+
1386
  if result["m"]:
1387
  node = result["m"]
1388
+ node_id = self._decode_graph_label(node["node_id"])
1389
+ if node_id not in nodes:
1390
+ nodes[node_id] = node
1391
  if result["r"]:
1392
  edge = result["r"]
1393
  src_id = self._decode_graph_label(edge["start_id"])
 
1396
  else:
1397
  if result["nodes"]:
1398
  for node in result["nodes"]:
1399
+ node_id = self._decode_graph_label(node["node_id"])
1400
+ if node_id not in nodes:
1401
+ nodes[node_id] = node
1402
+
1403
  if result["relationships"]:
1404
+ for edge in result["relationships"]: # src --DIRECTED--> target
1405
+ src_id = self._decode_graph_label(edge[0]["node_id"])
1406
+ tgt_id = self._decode_graph_label(edge[2]["node_id"])
1407
+ id = src_id + "," + tgt_id
1408
+ if id in unique_edge_ids:
1409
+ continue
1410
+ else:
1411
+ unique_edge_ids.add(id)
1412
+ edges.append(
1413
+ (id, src_id, tgt_id, {"source": edge[0], "target": edge[2]})
1414
+ )
1415
 
1416
  kg = KnowledgeGraph(
1417
+ nodes=[
1418
+ KnowledgeGraphNode(
1419
+ id=node_id, labels=[node_id], properties=nodes[node_id]
1420
+ )
1421
+ for node_id in nodes
1422
+ ],
1423
+ edges=[
1424
+ KnowledgeGraphEdge(
1425
+ id=id, type="DIRECTED", source=src, target=tgt, properties=props
1426
+ )
1427
+ for id, src, tgt, props in edges
1428
+ ],
1429
  )
1430
 
1431
  return kg