Commit
·
5876f53
1
Parent(s):
359e407
Fix the bug of AGE processing
Browse files
examples/lightrag_zhipu_postgres_demo.py
CHANGED
@@ -6,7 +6,7 @@ import time
|
|
6 |
from dotenv import load_dotenv
|
7 |
|
8 |
from lightrag import LightRAG, QueryParam
|
9 |
-
from lightrag.kg.postgres_impl import PostgreSQLDB
|
10 |
from lightrag.llm import ollama_embedding, zhipu_complete
|
11 |
from lightrag.utils import EmbeddingFunc
|
12 |
|
@@ -67,7 +67,6 @@ async def main():
|
|
67 |
rag.entities_vdb.db = postgres_db
|
68 |
rag.graph_storage_cls.db = postgres_db
|
69 |
rag.chunk_entity_relation_graph.db = postgres_db
|
70 |
-
await rag.chunk_entity_relation_graph.check_graph_exists()
|
71 |
# add embedding_func for graph database, it's deleted in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c
|
72 |
rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func
|
73 |
|
@@ -103,21 +102,6 @@ async def main():
|
|
103 |
)
|
104 |
print(f"Hybrid Query Time: {time.time() - start_time} seconds")
|
105 |
|
106 |
-
print("**** Start Stream Query ****")
|
107 |
-
start_time = time.time()
|
108 |
-
# stream response
|
109 |
-
resp = await rag.aquery(
|
110 |
-
"What are the top themes in this story?",
|
111 |
-
param=QueryParam(mode="hybrid", stream=True),
|
112 |
-
)
|
113 |
-
print(f"Stream Query Time: {time.time() - start_time} seconds")
|
114 |
-
print("**** Done Stream Query ****")
|
115 |
-
|
116 |
-
if inspect.isasyncgen(resp):
|
117 |
-
asyncio.run(print_stream(resp))
|
118 |
-
else:
|
119 |
-
print(resp)
|
120 |
-
|
121 |
|
122 |
if __name__ == "__main__":
|
123 |
asyncio.run(main())
|
|
|
6 |
from dotenv import load_dotenv
|
7 |
|
8 |
from lightrag import LightRAG, QueryParam
|
9 |
+
from lightrag.kg.postgres_impl import PostgreSQLDB
|
10 |
from lightrag.llm import ollama_embedding, zhipu_complete
|
11 |
from lightrag.utils import EmbeddingFunc
|
12 |
|
|
|
67 |
rag.entities_vdb.db = postgres_db
|
68 |
rag.graph_storage_cls.db = postgres_db
|
69 |
rag.chunk_entity_relation_graph.db = postgres_db
|
|
|
70 |
# add embedding_func for graph database, it's deleted in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c
|
71 |
rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func
|
72 |
|
|
|
102 |
)
|
103 |
print(f"Hybrid Query Time: {time.time() - start_time} seconds")
|
104 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
|
106 |
if __name__ == "__main__":
|
107 |
asyncio.run(main())
|
lightrag/kg/postgres_impl.py
CHANGED
@@ -81,12 +81,12 @@ class PostgreSQLDB:
|
|
81 |
|
82 |
|
83 |
async def query(
|
84 |
-
self, sql: str, params: dict = None, multirows: bool = False, for_age: bool = False
|
85 |
) -> Union[dict, None, list[dict]]:
|
86 |
async with self.pool.acquire() as connection:
|
87 |
try:
|
88 |
if for_age:
|
89 |
-
await
|
90 |
if params:
|
91 |
rows = await connection.fetch(sql, *params.values())
|
92 |
else:
|
@@ -95,10 +95,7 @@ class PostgreSQLDB:
|
|
95 |
if multirows:
|
96 |
if rows:
|
97 |
columns = [col for col in rows[0].keys()]
|
98 |
-
# print("columns", columns.__class__, columns)
|
99 |
-
# print("rows", rows)
|
100 |
data = [dict(zip(columns, row)) for row in rows]
|
101 |
-
# print("data", data)
|
102 |
else:
|
103 |
data = []
|
104 |
else:
|
@@ -114,11 +111,11 @@ class PostgreSQLDB:
|
|
114 |
print(params)
|
115 |
raise
|
116 |
|
117 |
-
async def execute(self, sql: str, data: Union[list, dict] = None, for_age: bool = False):
|
118 |
try:
|
119 |
async with self.pool.acquire() as connection:
|
120 |
if for_age:
|
121 |
-
await
|
122 |
|
123 |
if data is None:
|
124 |
await connection.execute(sql)
|
@@ -130,6 +127,14 @@ class PostgreSQLDB:
|
|
130 |
print(data)
|
131 |
raise
|
132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
|
134 |
@dataclass
|
135 |
class PGKVStorage(BaseKVStorage):
|
@@ -346,18 +351,14 @@ class PGVectorStorage(BaseVectorStorage):
|
|
346 |
embeddings = await self.embedding_func([query])
|
347 |
embedding = embeddings[0]
|
348 |
embedding_string = ",".join(map(str, embedding))
|
349 |
-
# print("Namespace", self.namespace)
|
350 |
|
351 |
sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
|
352 |
-
# print("sql is: ", sql)
|
353 |
params = {
|
354 |
"workspace": self.db.workspace,
|
355 |
"better_than_threshold": self.cosine_better_than_threshold,
|
356 |
"top_k": top_k,
|
357 |
}
|
358 |
-
# print("params is: ", params)
|
359 |
results = await self.db.query(sql, params=params, multirows=True)
|
360 |
-
print("vector search result:", results)
|
361 |
return results
|
362 |
|
363 |
@dataclass
|
@@ -487,19 +488,6 @@ class PGGraphStorage(BaseGraphStorage):
|
|
487 |
async def index_done_callback(self):
|
488 |
print("KG successfully indexed.")
|
489 |
|
490 |
-
async def check_graph_exists(self):
|
491 |
-
try:
|
492 |
-
res = await self.db.query(f"SELECT * FROM ag_catalog.ag_graph WHERE name = '{self.graph_name}'")
|
493 |
-
if res:
|
494 |
-
logger.info(f"Graph {self.graph_name} exists.")
|
495 |
-
else:
|
496 |
-
logger.info(f"Graph {self.graph_name} does not exist. Creating...")
|
497 |
-
await self.db.execute(f"SELECT create_graph('{self.graph_name}')", for_age=True)
|
498 |
-
logger.info(f"Graph {self.graph_name} created.")
|
499 |
-
except Exception as e:
|
500 |
-
logger.info(f"Failed to check/create graph {self.graph_name}:", e)
|
501 |
-
raise e
|
502 |
-
|
503 |
@staticmethod
|
504 |
def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]:
|
505 |
"""
|
@@ -572,7 +560,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
572 |
|
573 |
Args:
|
574 |
properties (Dict[str,str]): a dictionary containing node/edge properties
|
575 |
-
|
576 |
|
577 |
Returns:
|
578 |
str: the properties dictionary as a properly formatted string
|
@@ -591,7 +579,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
591 |
@staticmethod
|
592 |
def _encode_graph_label(label: str) -> str:
|
593 |
"""
|
594 |
-
Since AGE
|
595 |
|
596 |
Args:
|
597 |
label (str): the original label
|
@@ -604,7 +592,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
604 |
@staticmethod
|
605 |
def _decode_graph_label(encoded_label: str) -> str:
|
606 |
"""
|
607 |
-
Since AGE
|
608 |
|
609 |
Args:
|
610 |
encoded_label (str): the encoded label
|
@@ -656,8 +644,8 @@ class PGGraphStorage(BaseGraphStorage):
|
|
656 |
|
657 |
# pgsql template
|
658 |
template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$
|
659 |
-
|
660 |
-
$$) AS ({fields})
|
661 |
|
662 |
# if there are any returned fields they must be added to the pgsql query
|
663 |
if "return" in query.lower():
|
@@ -702,7 +690,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
702 |
projection=select_str,
|
703 |
)
|
704 |
|
705 |
-
async def _query(self, query: str, readonly=True, **params: str) -> List[Dict[str, Any]]:
|
706 |
"""
|
707 |
Query the graph by taking a cypher query, converting it to an
|
708 |
age compatible query, executing it and converting the result
|
@@ -720,9 +708,14 @@ class PGGraphStorage(BaseGraphStorage):
|
|
720 |
# execute the query, rolling back on an error
|
721 |
try:
|
722 |
if readonly:
|
723 |
-
data = await self.db.query(wrapped_query, multirows=True, for_age=True)
|
724 |
else:
|
725 |
-
|
|
|
|
|
|
|
|
|
|
|
726 |
except Exception as e:
|
727 |
raise PGGraphQueryException(
|
728 |
{
|
@@ -743,9 +736,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
743 |
async def has_node(self, node_id: str) -> bool:
|
744 |
entity_name_label = node_id.strip('"')
|
745 |
|
746 |
-
query = """
|
747 |
-
MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists
|
748 |
-
"""
|
749 |
params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
|
750 |
single_result = (await self._query(query, **params))[0]
|
751 |
logger.debug(
|
@@ -761,10 +752,8 @@ class PGGraphStorage(BaseGraphStorage):
|
|
761 |
entity_name_label_source = source_node_id.strip('"')
|
762 |
entity_name_label_target = target_node_id.strip('"')
|
763 |
|
764 |
-
query = """
|
765 |
-
|
766 |
-
RETURN COUNT(r) > 0 AS edge_exists
|
767 |
-
"""
|
768 |
params = {
|
769 |
"src_label": PGGraphStorage._encode_graph_label(entity_name_label_source),
|
770 |
"tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target),
|
@@ -780,9 +769,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
780 |
|
781 |
async def get_node(self, node_id: str) -> Union[dict, None]:
|
782 |
entity_name_label = node_id.strip('"')
|
783 |
-
query = """
|
784 |
-
MATCH (n:`{label}`) RETURN n
|
785 |
-
"""
|
786 |
params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
|
787 |
record = await self._query(query, **params)
|
788 |
if record:
|
@@ -800,10 +787,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
800 |
async def node_degree(self, node_id: str) -> int:
|
801 |
entity_name_label = node_id.strip('"')
|
802 |
|
803 |
-
query = """
|
804 |
-
MATCH (n:`{label}`)-[]->(x)
|
805 |
-
RETURN count(x) AS total_edge_count
|
806 |
-
"""
|
807 |
params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
|
808 |
record = (await self._query(query, **params))[0]
|
809 |
if record:
|
@@ -841,8 +825,8 @@ class PGGraphStorage(BaseGraphStorage):
|
|
841 |
Find all edges between nodes of two given labels
|
842 |
|
843 |
Args:
|
844 |
-
|
845 |
-
|
846 |
|
847 |
Returns:
|
848 |
list: List of all relationships/edges found
|
@@ -850,11 +834,9 @@ class PGGraphStorage(BaseGraphStorage):
|
|
850 |
entity_name_label_source = source_node_id.strip('"')
|
851 |
entity_name_label_target = target_node_id.strip('"')
|
852 |
|
853 |
-
query = """
|
854 |
-
MATCH (a:`{src_label}`)-[r]->(b:`{tgt_label}`)
|
855 |
RETURN properties(r) as edge_properties
|
856 |
-
LIMIT 1
|
857 |
-
"""
|
858 |
params = {
|
859 |
"src_label": PGGraphStorage._encode_graph_label(entity_name_label_source),
|
860 |
"tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target),
|
@@ -877,11 +859,9 @@ class PGGraphStorage(BaseGraphStorage):
|
|
877 |
"""
|
878 |
node_label = source_node_id.strip('"')
|
879 |
|
880 |
-
query = """
|
881 |
-
MATCH (n:`{label}`)
|
882 |
OPTIONAL MATCH (n)-[r]-(connected)
|
883 |
-
RETURN n, r, connected
|
884 |
-
"""
|
885 |
params = {"label": PGGraphStorage._encode_graph_label(node_label)}
|
886 |
results = await self._query(query, **params)
|
887 |
edges = []
|
@@ -919,10 +899,8 @@ class PGGraphStorage(BaseGraphStorage):
|
|
919 |
label = node_id.strip('"')
|
920 |
properties = node_data
|
921 |
|
922 |
-
query = """
|
923 |
-
|
924 |
-
SET n += {properties}
|
925 |
-
"""
|
926 |
params = {
|
927 |
"label": PGGraphStorage._encode_graph_label(label),
|
928 |
"properties": PGGraphStorage._format_properties(properties),
|
@@ -957,22 +935,22 @@ class PGGraphStorage(BaseGraphStorage):
|
|
957 |
source_node_label = source_node_id.strip('"')
|
958 |
target_node_label = target_node_id.strip('"')
|
959 |
edge_properties = edge_data
|
|
|
960 |
|
961 |
-
query = """
|
962 |
-
MATCH (source:`{src_label}`)
|
963 |
WITH source
|
964 |
MATCH (target:`{tgt_label}`)
|
965 |
MERGE (source)-[r:DIRECTED]->(target)
|
966 |
SET r += {properties}
|
967 |
-
RETURN r
|
968 |
-
"""
|
969 |
params = {
|
970 |
"src_label": PGGraphStorage._encode_graph_label(source_node_label),
|
971 |
"tgt_label": PGGraphStorage._encode_graph_label(target_node_label),
|
972 |
"properties": PGGraphStorage._format_properties(edge_properties),
|
973 |
}
|
|
|
974 |
try:
|
975 |
-
await self._query(query, readonly=False, **params)
|
976 |
logger.debug(
|
977 |
"Upserted edge from '{%s}' to '{%s}' with properties: {%s}",
|
978 |
source_node_label,
|
@@ -1127,7 +1105,7 @@ SQL_TEMPLATES = {
|
|
1127 |
updatetime = CURRENT_TIMESTAMP
|
1128 |
""",
|
1129 |
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector)
|
1130 |
-
VALUES ($1, $2, $3, $4, $5
|
1131 |
ON CONFLICT (workspace,id) DO UPDATE
|
1132 |
SET entity_name=EXCLUDED.entity_name,
|
1133 |
content=EXCLUDED.content,
|
|
|
81 |
|
82 |
|
83 |
async def query(
|
84 |
+
self, sql: str, params: dict = None, multirows: bool = False, for_age: bool = False, graph_name: str = None
|
85 |
) -> Union[dict, None, list[dict]]:
|
86 |
async with self.pool.acquire() as connection:
|
87 |
try:
|
88 |
if for_age:
|
89 |
+
await PostgreSQLDB._prerequisite(connection, graph_name)
|
90 |
if params:
|
91 |
rows = await connection.fetch(sql, *params.values())
|
92 |
else:
|
|
|
95 |
if multirows:
|
96 |
if rows:
|
97 |
columns = [col for col in rows[0].keys()]
|
|
|
|
|
98 |
data = [dict(zip(columns, row)) for row in rows]
|
|
|
99 |
else:
|
100 |
data = []
|
101 |
else:
|
|
|
111 |
print(params)
|
112 |
raise
|
113 |
|
114 |
+
async def execute(self, sql: str, data: Union[list, dict] = None, for_age: bool = False, graph_name: str = None):
|
115 |
try:
|
116 |
async with self.pool.acquire() as connection:
|
117 |
if for_age:
|
118 |
+
await PostgreSQLDB._prerequisite(connection, graph_name)
|
119 |
|
120 |
if data is None:
|
121 |
await connection.execute(sql)
|
|
|
127 |
print(data)
|
128 |
raise
|
129 |
|
130 |
+
@staticmethod
async def _prerequisite(conn: asyncpg.Connection, graph_name: str):
    """Prepare a connection for Apache AGE statements.

    Puts ``ag_catalog`` first on the session ``search_path`` and then
    attempts to create the graph, so that cypher calls issued right after
    on the same connection can resolve AGE objects.

    Args:
        conn: the connection the AGE statement is about to run on.
        graph_name: name of the AGE graph to create if it is missing.
    """
    # Double any single quote so the name cannot terminate the SQL literal.
    # NOTE(review): a ``None`` graph_name is still rendered as the literal
    # name "None", exactly as before -- callers should always pass a name.
    safe_name = str(graph_name).replace("'", "''")
    try:
        await conn.execute('SET search_path = ag_catalog, "$user", public')
        await conn.execute(f"""select create_graph('{safe_name}')""")
    except (
        asyncpg.exceptions.InvalidSchemaNameError,
        asyncpg.exceptions.UniqueViolationError,
    ):
        # The graph is already there (or was created concurrently by another
        # connection); nothing to do.
        pass
|
137 |
+
|
138 |
|
139 |
@dataclass
|
140 |
class PGKVStorage(BaseKVStorage):
|
|
|
351 |
embeddings = await self.embedding_func([query])
|
352 |
embedding = embeddings[0]
|
353 |
embedding_string = ",".join(map(str, embedding))
|
|
|
354 |
|
355 |
sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
|
|
|
356 |
params = {
|
357 |
"workspace": self.db.workspace,
|
358 |
"better_than_threshold": self.cosine_better_than_threshold,
|
359 |
"top_k": top_k,
|
360 |
}
|
|
|
361 |
results = await self.db.query(sql, params=params, multirows=True)
|
|
|
362 |
return results
|
363 |
|
364 |
@dataclass
|
|
|
488 |
async def index_done_callback(self):
|
489 |
print("KG successfully indexed.")
|
490 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
491 |
@staticmethod
|
492 |
def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]:
|
493 |
"""
|
|
|
560 |
|
561 |
Args:
|
562 |
properties (Dict[str,str]): a dictionary containing node/edge properties
|
563 |
+
_id (Union[str, None]): the id of the node or None if none exists
|
564 |
|
565 |
Returns:
|
566 |
str: the properties dictionary as a properly formatted string
|
|
|
579 |
@staticmethod
|
580 |
def _encode_graph_label(label: str) -> str:
|
581 |
"""
|
582 |
+
Since AGE supports only alphanumerical labels, we will encode generic label as HEX string
|
583 |
|
584 |
Args:
|
585 |
label (str): the original label
|
|
|
592 |
@staticmethod
|
593 |
def _decode_graph_label(encoded_label: str) -> str:
|
594 |
"""
|
595 |
+
Since AGE supports only alphanumerical labels, we will encode generic label as HEX string
|
596 |
|
597 |
Args:
|
598 |
encoded_label (str): the encoded label
|
|
|
644 |
|
645 |
# pgsql template
|
646 |
template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$
|
647 |
+
{query}
|
648 |
+
$$) AS ({fields})"""
|
649 |
|
650 |
# if there are any returned fields they must be added to the pgsql query
|
651 |
if "return" in query.lower():
|
|
|
690 |
projection=select_str,
|
691 |
)
|
692 |
|
693 |
+
async def _query(self, query: str, readonly=True, upsert_edge=False, **params: str) -> List[Dict[str, Any]]:
|
694 |
"""
|
695 |
Query the graph by taking a cypher query, converting it to an
|
696 |
age compatible query, executing it and converting the result
|
|
|
708 |
# execute the query, rolling back on an error
|
709 |
try:
|
710 |
if readonly:
|
711 |
+
data = await self.db.query(wrapped_query, multirows=True, for_age=True, graph_name=self.graph_name)
|
712 |
else:
|
713 |
+
# for upserting edge, need to run the SQL twice, otherwise cannot update the properties. (First time it will try to create the edge, second time is MERGING)
|
714 |
+
# It is a bug of AGE as of 2025-01-03, hope it can be resolved in the future.
|
715 |
+
if upsert_edge:
|
716 |
+
data = await self.db.execute(f"{wrapped_query};{wrapped_query};", for_age=True, graph_name=self.graph_name)
|
717 |
+
else:
|
718 |
+
data = await self.db.execute(wrapped_query, for_age=True, graph_name=self.graph_name)
|
719 |
except Exception as e:
|
720 |
raise PGGraphQueryException(
|
721 |
{
|
|
|
736 |
async def has_node(self, node_id: str) -> bool:
|
737 |
entity_name_label = node_id.strip('"')
|
738 |
|
739 |
+
query = """MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists"""
|
|
|
|
|
740 |
params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
|
741 |
single_result = (await self._query(query, **params))[0]
|
742 |
logger.debug(
|
|
|
752 |
entity_name_label_source = source_node_id.strip('"')
|
753 |
entity_name_label_target = target_node_id.strip('"')
|
754 |
|
755 |
+
query = """MATCH (a:`{src_label}`)-[r]-(b:`{tgt_label}`)
|
756 |
+
RETURN COUNT(r) > 0 AS edge_exists"""
|
|
|
|
|
757 |
params = {
|
758 |
"src_label": PGGraphStorage._encode_graph_label(entity_name_label_source),
|
759 |
"tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target),
|
|
|
769 |
|
770 |
async def get_node(self, node_id: str) -> Union[dict, None]:
|
771 |
entity_name_label = node_id.strip('"')
|
772 |
+
query = """MATCH (n:`{label}`) RETURN n"""
|
|
|
|
|
773 |
params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
|
774 |
record = await self._query(query, **params)
|
775 |
if record:
|
|
|
787 |
async def node_degree(self, node_id: str) -> int:
|
788 |
entity_name_label = node_id.strip('"')
|
789 |
|
790 |
+
query = """MATCH (n:`{label}`)-[]->(x) RETURN count(x) AS total_edge_count"""
|
|
|
|
|
|
|
791 |
params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
|
792 |
record = (await self._query(query, **params))[0]
|
793 |
if record:
|
|
|
825 |
Find all edges between nodes of two given labels
|
826 |
|
827 |
Args:
|
828 |
+
source_node_id (str): Label of the source nodes
|
829 |
+
target_node_id (str): Label of the target nodes
|
830 |
|
831 |
Returns:
|
832 |
list: List of all relationships/edges found
|
|
|
834 |
entity_name_label_source = source_node_id.strip('"')
|
835 |
entity_name_label_target = target_node_id.strip('"')
|
836 |
|
837 |
+
query = """MATCH (a:`{src_label}`)-[r]->(b:`{tgt_label}`)
|
|
|
838 |
RETURN properties(r) as edge_properties
|
839 |
+
LIMIT 1"""
|
|
|
840 |
params = {
|
841 |
"src_label": PGGraphStorage._encode_graph_label(entity_name_label_source),
|
842 |
"tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target),
|
|
|
859 |
"""
|
860 |
node_label = source_node_id.strip('"')
|
861 |
|
862 |
+
query = """MATCH (n:`{label}`)
|
|
|
863 |
OPTIONAL MATCH (n)-[r]-(connected)
|
864 |
+
RETURN n, r, connected"""
|
|
|
865 |
params = {"label": PGGraphStorage._encode_graph_label(node_label)}
|
866 |
results = await self._query(query, **params)
|
867 |
edges = []
|
|
|
899 |
label = node_id.strip('"')
|
900 |
properties = node_data
|
901 |
|
902 |
+
query = """MERGE (n:`{label}`)
|
903 |
+
SET n += {properties}"""
|
|
|
|
|
904 |
params = {
|
905 |
"label": PGGraphStorage._encode_graph_label(label),
|
906 |
"properties": PGGraphStorage._format_properties(properties),
|
|
|
935 |
source_node_label = source_node_id.strip('"')
|
936 |
target_node_label = target_node_id.strip('"')
|
937 |
edge_properties = edge_data
|
938 |
+
logger.info(f"-- inserting edge: {source_node_label} -> {target_node_label}: {edge_data}")
|
939 |
|
940 |
+
query = """MATCH (source:`{src_label}`)
|
|
|
941 |
WITH source
|
942 |
MATCH (target:`{tgt_label}`)
|
943 |
MERGE (source)-[r:DIRECTED]->(target)
|
944 |
SET r += {properties}
|
945 |
+
RETURN r"""
|
|
|
946 |
params = {
|
947 |
"src_label": PGGraphStorage._encode_graph_label(source_node_label),
|
948 |
"tgt_label": PGGraphStorage._encode_graph_label(target_node_label),
|
949 |
"properties": PGGraphStorage._format_properties(edge_properties),
|
950 |
}
|
951 |
+
# logger.info(f"-- inserting edge after formatted: {params}")
|
952 |
try:
|
953 |
+
await self._query(query, readonly=False, upsert_edge=True, **params)
|
954 |
logger.debug(
|
955 |
"Upserted edge from '{%s}' to '{%s}' with properties: {%s}",
|
956 |
source_node_label,
|
|
|
1105 |
updatetime = CURRENT_TIMESTAMP
|
1106 |
""",
|
1107 |
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector)
|
1108 |
+
VALUES ($1, $2, $3, $4, $5)
|
1109 |
ON CONFLICT (workspace,id) DO UPDATE
|
1110 |
SET entity_name=EXCLUDED.entity_name,
|
1111 |
content=EXCLUDED.content,
|
lightrag/kg/postgres_impl_test.py
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
import asyncpg
|
3 |
+
import sys, os
|
4 |
+
|
5 |
+
import psycopg
|
6 |
+
from psycopg_pool import AsyncConnectionPool
|
7 |
+
from lightrag.kg.postgres_impl import PostgreSQLDB, PGGraphStorage
|
8 |
+
|
9 |
+
DB="rag"
|
10 |
+
USER="rag"
|
11 |
+
PASSWORD="rag"
|
12 |
+
HOST="localhost"
|
13 |
+
PORT="15432"
|
14 |
+
os.environ["AGE_GRAPH_NAME"] = "dickens"
|
15 |
+
|
16 |
+
if sys.platform.startswith("win"):
|
17 |
+
import asyncio.windows_events
|
18 |
+
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
19 |
+
|
20 |
+
async def get_pool():
    """Create and return an asyncpg connection pool for the test database."""
    dsn = f"postgres://{USER}:{PASSWORD}@{HOST}:{PORT}/{DB}"
    pool = await asyncpg.create_pool(
        dsn,
        # Minimum number of connections opened when the pool is initialised (default 10)
        min_size=10,
        # Maximum number of connections in the pool (default 10)
        max_size=10,
        # Queries served per connection before it is replaced by a new one (default 5000)
        max_queries=5000,
        # Maximum idle time (default 300.0); connections idle for longer are
        # closed, and passing 0 means they are never closed
        max_inactive_connection_lifetime=300.0,
    )
    return pool
|
29 |
+
|
30 |
+
async def main1():
    """Create the ``dickens-2`` AGE graph through a psycopg async pool.

    Opens a pool, borrows one connection, sets the AGE search path and calls
    ``create_graph``. An already-existing graph is reported and the
    transaction rolled back instead of being treated as a failure. The
    connection is always handed back and the pool always closed.
    """
    connection_string = f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}"
    pool = AsyncConnectionPool(connection_string, open=False)
    await pool.open()

    try:
        conn = await pool.getconn(timeout=10)
        try:
            async with conn.cursor() as curs:
                try:
                    await curs.execute('SET search_path = ag_catalog, "$user", public')
                    await curs.execute("SELECT create_graph('dickens-2')")
                    await conn.commit()
                    print("create_graph success")
                except (
                    psycopg.errors.InvalidSchemaName,
                    psycopg.errors.UniqueViolation,
                ):
                    print("create_graph already exists")
                    await conn.rollback()
        finally:
            # getconn() does not return the connection by itself
            await pool.putconn(conn)
    finally:
        await pool.close()
|
51 |
+
|
52 |
+
# Shared PostgreSQLDB handle used by the PGGraphStorage helpers below.
# Built from the module-level connection constants so that every helper in
# this file targets the same database.
db = PostgreSQLDB(
    config={
        "host": HOST,
        "port": int(PORT),  # PORT is kept as a string for the DSN builders
        "user": USER,
        "password": PASSWORD,
        "database": DB,
    }
)
|
61 |
+
|
62 |
+
async def query_with_age():
    """Look up the ``"CHRISTMAS-TIME"`` node via PGGraphStorage and print it."""
    await db.initdb()

    storage_kwargs = {
        "namespace": "chunk_entity_relation",
        "global_config": {},
        "embedding_func": None,
    }
    storage = PGGraphStorage(**storage_kwargs)
    # The storage object does not open its own connection; hand it the shared one.
    storage.db = db

    node = await storage.get_node('"CHRISTMAS-TIME"')
    print("Node is: ", node)
|
72 |
+
|
73 |
+
async def create_edge_with_age():
    """Upsert two nodes and an edge between them, then read the edge back."""
    await db.initdb()

    storage = PGGraphStorage(
        namespace="chunk_entity_relation",
        global_config={},
        embedding_func=None,
    )
    storage.db = db

    source_id = '"THE CRATCHITS"'
    target_id = '"THE GIRLS"'

    await storage.upsert_node(source_id, {"hello": "world"})
    await storage.upsert_node(target_id, {"world": "hello"})

    edge_payload = {
        "weight": 7.0,
        "description": '"The girls are part of the Cratchit family, contributing to their collective efforts and shared experiences.',
        "keywords": '"family, collective effort"',
        "source_id": "chunk-1d4b58de5429cd1261370c231c8673e8",
    }
    await storage.upsert_edge(source_id, target_id, edge_data=edge_payload)

    # The source id is passed without surrounding double quotes here, as in the original call.
    edge = await storage.get_edge('THE CRATCHITS', target_id)
    print("Edge is: ", edge)
|
95 |
+
|
96 |
+
|
97 |
+
async def main():
    """Run a raw AGE cypher query through an asyncpg pool and print the rows.

    The pool created here is closed before returning, whether or not the
    queries succeed.
    """
    pool = await get_pool()
    # Any other special parameters can be passed straight through as well,
    # because **connect_kwargs is accepted; it exists for setting attributes
    # that are specific to a particular database.
    sql = r"SELECT * FROM ag_catalog.cypher('dickens', $$ MATCH (n:帅哥) RETURN n $$) AS (n ag_catalog.agtype)"
    # cypher = "MATCH (n:how_are_you_doing) RETURN n"
    try:
        # Take one connection out of the pool
        async with pool.acquire() as conn:
            try:
                await conn.execute("""SET search_path = ag_catalog, "$user", public;select create_graph('dickens')""")
            except asyncpg.exceptions.InvalidSchemaNameError:
                print("create_graph already exists")
            # stmt = await conn.prepare(sql)
            row = await conn.fetch(sql)
            print("row is: ", row)

            # The workaround is to give the expression an alias
            row = await conn.fetchrow("select '100'::int + 200 as result")
            print(row)  # <Record result=300>
            # The connection came from the pool; it is put back automatically
            # once this context ends
    finally:
        await pool.close()
|
117 |
+
|
118 |
+
|
119 |
+
if __name__ == '__main__':
|
120 |
+
asyncio.run(query_with_age())
|
121 |
+
|
122 |
+
|