Commit
·
3bda675
1
Parent(s):
c55e3cb
fix the postgres get all labels and get knowledge graph
Browse files- examples/test_postgres.py +51 -0
- lightrag/kg/postgres_impl.py +110 -41
examples/test_postgres.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import asyncio
|
3 |
+
from lightrag.kg.postgres_impl import PGGraphStorage
|
4 |
+
from lightrag.llm.ollama import ollama_embedding
|
5 |
+
from lightrag.utils import EmbeddingFunc
|
6 |
+
|
7 |
+
#########
|
8 |
+
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
|
9 |
+
# import nest_asyncio
|
10 |
+
# nest_asyncio.apply()
|
11 |
+
#########
|
12 |
+
|
13 |
+
WORKING_DIR = "./local_neo4jWorkDir"
|
14 |
+
|
15 |
+
if not os.path.exists(WORKING_DIR):
|
16 |
+
os.mkdir(WORKING_DIR)
|
17 |
+
|
18 |
+
# AGE
|
19 |
+
os.environ["AGE_GRAPH_NAME"] = "dickens"
|
20 |
+
|
21 |
+
os.environ["POSTGRES_HOST"] = "localhost"
|
22 |
+
os.environ["POSTGRES_PORT"] = "15432"
|
23 |
+
os.environ["POSTGRES_USER"] = "rag"
|
24 |
+
os.environ["POSTGRES_PASSWORD"] = "rag"
|
25 |
+
os.environ["POSTGRES_DATABASE"] = "rag"
|
26 |
+
|
27 |
+
|
28 |
+
async def main():
|
29 |
+
graph_db = PGGraphStorage(
|
30 |
+
namespace="dickens",
|
31 |
+
embedding_func=EmbeddingFunc(
|
32 |
+
embedding_dim=1024,
|
33 |
+
max_token_size=8192,
|
34 |
+
func=lambda texts: ollama_embedding(
|
35 |
+
texts, embed_model="bge-m3", host="http://localhost:11434"
|
36 |
+
),
|
37 |
+
),
|
38 |
+
global_config={},
|
39 |
+
)
|
40 |
+
await graph_db.initialize()
|
41 |
+
labels = await graph_db.get_all_labels()
|
42 |
+
print("all labels", labels)
|
43 |
+
|
44 |
+
res = await graph_db.get_knowledge_graph("FEZZIWIG")
|
45 |
+
print("knowledge graphs", res)
|
46 |
+
|
47 |
+
await graph_db.finalize()
|
48 |
+
|
49 |
+
|
50 |
+
if __name__ == "__main__":
|
51 |
+
asyncio.run(main())
|
lightrag/kg/postgres_impl.py
CHANGED
@@ -810,42 +810,85 @@ class PGGraphStorage(BaseGraphStorage):
|
|
810 |
v = record[k]
|
811 |
# agtype comes back '{key: value}::type' which must be parsed
|
812 |
if isinstance(v, str) and "::" in v:
|
813 |
-
|
814 |
-
|
815 |
-
|
816 |
-
|
817 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
818 |
|
819 |
# iterate returned fields and parse appropriately
|
820 |
for k in record.keys():
|
821 |
v = record[k]
|
822 |
if isinstance(v, str) and "::" in v:
|
823 |
-
|
824 |
-
|
825 |
-
|
826 |
-
|
827 |
-
|
828 |
-
|
829 |
-
|
830 |
-
|
831 |
-
|
832 |
-
|
833 |
-
|
834 |
-
|
835 |
-
|
836 |
-
|
837 |
-
|
838 |
-
|
839 |
-
|
840 |
-
|
841 |
-
|
842 |
-
|
843 |
-
|
844 |
-
|
845 |
-
|
846 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
847 |
else:
|
848 |
-
|
|
|
|
|
|
|
849 |
|
850 |
return d
|
851 |
|
@@ -1319,7 +1362,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1319 |
OPTIONAL MATCH p = (n)-[*..%d]-(m)
|
1320 |
RETURN nodes(p) AS nodes, relationships(p) AS relationships
|
1321 |
LIMIT %d
|
1322 |
-
$$) AS (nodes agtype
|
1323 |
self.graph_name,
|
1324 |
encoded_node_label,
|
1325 |
max_depth,
|
@@ -1328,17 +1371,23 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1328 |
|
1329 |
results = await self._query(query)
|
1330 |
|
1331 |
-
nodes =
|
1332 |
edges = []
|
|
|
1333 |
|
1334 |
for result in results:
|
1335 |
if node_label == "*":
|
1336 |
if result["n"]:
|
1337 |
node = result["n"]
|
1338 |
-
|
|
|
|
|
|
|
1339 |
if result["m"]:
|
1340 |
node = result["m"]
|
1341 |
-
|
|
|
|
|
1342 |
if result["r"]:
|
1343 |
edge = result["r"]
|
1344 |
src_id = self._decode_graph_label(edge["start_id"])
|
@@ -1347,16 +1396,36 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1347 |
else:
|
1348 |
if result["nodes"]:
|
1349 |
for node in result["nodes"]:
|
1350 |
-
|
|
|
|
|
|
|
1351 |
if result["relationships"]:
|
1352 |
-
for edge in result["relationships"]:
|
1353 |
-
src_id = self._decode_graph_label(edge["
|
1354 |
-
tgt_id = self._decode_graph_label(edge["
|
1355 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1356 |
|
1357 |
kg = KnowledgeGraph(
|
1358 |
-
nodes=[
|
1359 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1360 |
)
|
1361 |
|
1362 |
return kg
|
|
|
810 |
v = record[k]
|
811 |
# agtype comes back '{key: value}::type' which must be parsed
|
812 |
if isinstance(v, str) and "::" in v:
|
813 |
+
if v.startswith("[") and v.endswith("]"):
|
814 |
+
if "::vertex" not in v:
|
815 |
+
continue
|
816 |
+
v = v.replace("::vertex", "")
|
817 |
+
vertexes = json.loads(v)
|
818 |
+
for vertex in vertexes:
|
819 |
+
vertices[vertex["id"]] = vertex.get("properties")
|
820 |
+
else:
|
821 |
+
dtype = v.split("::")[-1]
|
822 |
+
v = v.split("::")[0]
|
823 |
+
if dtype == "vertex":
|
824 |
+
vertex = json.loads(v)
|
825 |
+
vertices[vertex["id"]] = vertex.get("properties")
|
826 |
|
827 |
# iterate returned fields and parse appropriately
|
828 |
for k in record.keys():
|
829 |
v = record[k]
|
830 |
if isinstance(v, str) and "::" in v:
|
831 |
+
if v.startswith("[") and v.endswith("]"):
|
832 |
+
if "::vertex" in v:
|
833 |
+
v = v.replace("::vertex", "")
|
834 |
+
vertexes = json.loads(v)
|
835 |
+
dl = []
|
836 |
+
for vertex in vertexes:
|
837 |
+
prop = vertex.get("properties")
|
838 |
+
if not prop:
|
839 |
+
prop = {}
|
840 |
+
prop["label"] = PGGraphStorage._decode_graph_label(
|
841 |
+
prop["node_id"]
|
842 |
+
)
|
843 |
+
dl.append(prop)
|
844 |
+
d[k] = dl
|
845 |
+
|
846 |
+
elif "::edge" in v:
|
847 |
+
v = v.replace("::edge", "")
|
848 |
+
edges = json.loads(v)
|
849 |
+
dl = []
|
850 |
+
for edge in edges:
|
851 |
+
dl.append(
|
852 |
+
(
|
853 |
+
vertices[edge["start_id"]],
|
854 |
+
edge["label"],
|
855 |
+
vertices[edge["end_id"]],
|
856 |
+
)
|
857 |
+
)
|
858 |
+
d[k] = dl
|
859 |
+
else:
|
860 |
+
print("WARNING: unsupported type")
|
861 |
+
continue
|
862 |
+
|
863 |
+
else:
|
864 |
+
dtype = v.split("::")[-1]
|
865 |
+
v = v.split("::")[0]
|
866 |
+
if dtype == "vertex":
|
867 |
+
vertex = json.loads(v)
|
868 |
+
field = vertex.get("properties")
|
869 |
+
if not field:
|
870 |
+
field = {}
|
871 |
+
field["label"] = PGGraphStorage._decode_graph_label(
|
872 |
+
field["node_id"]
|
873 |
+
)
|
874 |
+
d[k] = field
|
875 |
+
# convert edge from id-label->id by replacing id with node information
|
876 |
+
# we only do this if the vertex was also returned in the query
|
877 |
+
# this is an attempt to be consistent with neo4j implementation
|
878 |
+
elif dtype == "edge":
|
879 |
+
edge = json.loads(v)
|
880 |
+
d[k] = (
|
881 |
+
vertices.get(edge["start_id"], {}),
|
882 |
+
edge[
|
883 |
+
"label"
|
884 |
+
], # we don't use decode_graph_label(), since edge label is always "DIRECTED"
|
885 |
+
vertices.get(edge["end_id"], {}),
|
886 |
+
)
|
887 |
else:
|
888 |
+
if v is None or (v.count("{") < 1 and v.count("[") < 1):
|
889 |
+
d[k] = v
|
890 |
+
else:
|
891 |
+
d[k] = json.loads(v) if isinstance(v, str) else v
|
892 |
|
893 |
return d
|
894 |
|
|
|
1362 |
OPTIONAL MATCH p = (n)-[*..%d]-(m)
|
1363 |
RETURN nodes(p) AS nodes, relationships(p) AS relationships
|
1364 |
LIMIT %d
|
1365 |
+
$$) AS (nodes agtype, relationships agtype)""" % (
|
1366 |
self.graph_name,
|
1367 |
encoded_node_label,
|
1368 |
max_depth,
|
|
|
1371 |
|
1372 |
results = await self._query(query)
|
1373 |
|
1374 |
+
nodes = {}
|
1375 |
edges = []
|
1376 |
+
unique_edge_ids = set()
|
1377 |
|
1378 |
for result in results:
|
1379 |
if node_label == "*":
|
1380 |
if result["n"]:
|
1381 |
node = result["n"]
|
1382 |
+
node_id = self._decode_graph_label(node["node_id"])
|
1383 |
+
if node_id not in nodes:
|
1384 |
+
nodes[node_id] = node
|
1385 |
+
|
1386 |
if result["m"]:
|
1387 |
node = result["m"]
|
1388 |
+
node_id = self._decode_graph_label(node["node_id"])
|
1389 |
+
if node_id not in nodes:
|
1390 |
+
nodes[node_id] = node
|
1391 |
if result["r"]:
|
1392 |
edge = result["r"]
|
1393 |
src_id = self._decode_graph_label(edge["start_id"])
|
|
|
1396 |
else:
|
1397 |
if result["nodes"]:
|
1398 |
for node in result["nodes"]:
|
1399 |
+
node_id = self._decode_graph_label(node["node_id"])
|
1400 |
+
if node_id not in nodes:
|
1401 |
+
nodes[node_id] = node
|
1402 |
+
|
1403 |
if result["relationships"]:
|
1404 |
+
for edge in result["relationships"]: # src --DIRECTED--> target
|
1405 |
+
src_id = self._decode_graph_label(edge[0]["node_id"])
|
1406 |
+
tgt_id = self._decode_graph_label(edge[2]["node_id"])
|
1407 |
+
id = src_id + "," + tgt_id
|
1408 |
+
if id in unique_edge_ids:
|
1409 |
+
continue
|
1410 |
+
else:
|
1411 |
+
unique_edge_ids.add(id)
|
1412 |
+
edges.append(
|
1413 |
+
(id, src_id, tgt_id, {"source": edge[0], "target": edge[2]})
|
1414 |
+
)
|
1415 |
|
1416 |
kg = KnowledgeGraph(
|
1417 |
+
nodes=[
|
1418 |
+
KnowledgeGraphNode(
|
1419 |
+
id=node_id, labels=[node_id], properties=nodes[node_id]
|
1420 |
+
)
|
1421 |
+
for node_id in nodes
|
1422 |
+
],
|
1423 |
+
edges=[
|
1424 |
+
KnowledgeGraphEdge(
|
1425 |
+
id=id, type="DIRECTED", source=src, target=tgt, properties=props
|
1426 |
+
)
|
1427 |
+
for id, src, tgt, props in edges
|
1428 |
+
],
|
1429 |
)
|
1430 |
|
1431 |
return kg
|