Commit
·
d19a515
1
Parent(s):
4d8e9a6
Fix the lint issue
Browse files- examples/lightrag_zhipu_postgres_demo.py +13 -5
- lightrag/kg/postgres_impl.py +100 -57
- lightrag/kg/postgres_impl_test.py +33 -27
- requirements.txt +9 -9
examples/lightrag_zhipu_postgres_demo.py
CHANGED
@@ -53,7 +53,7 @@ async def main():
|
|
53 |
kv_storage="PGKVStorage",
|
54 |
doc_status_storage="PGDocStatusStorage",
|
55 |
graph_storage="PGGraphStorage",
|
56 |
-
vector_storage="PGVectorStorage"
|
57 |
)
|
58 |
# Set the KV/vector/graph storage's `db` property, so all operation will use same connection pool
|
59 |
rag.doc_status.db = postgres_db
|
@@ -77,27 +77,35 @@ async def main():
|
|
77 |
start_time = time.time()
|
78 |
# Perform naive search
|
79 |
print(
|
80 |
-
await rag.aquery(
|
|
|
|
|
81 |
)
|
82 |
print(f"Naive Query Time: {time.time() - start_time} seconds")
|
83 |
# Perform local search
|
84 |
print("**** Start Local Query ****")
|
85 |
start_time = time.time()
|
86 |
print(
|
87 |
-
await rag.aquery(
|
|
|
|
|
88 |
)
|
89 |
print(f"Local Query Time: {time.time() - start_time} seconds")
|
90 |
# Perform global search
|
91 |
print("**** Start Global Query ****")
|
92 |
start_time = time.time()
|
93 |
print(
|
94 |
-
await rag.aquery(
|
|
|
|
|
95 |
)
|
96 |
print(f"Global Query Time: {time.time() - start_time}")
|
97 |
# Perform hybrid search
|
98 |
print("**** Start Hybrid Query ****")
|
99 |
print(
|
100 |
-
await rag.aquery(
|
|
|
|
|
101 |
)
|
102 |
print(f"Hybrid Query Time: {time.time() - start_time} seconds")
|
103 |
|
|
|
53 |
kv_storage="PGKVStorage",
|
54 |
doc_status_storage="PGDocStatusStorage",
|
55 |
graph_storage="PGGraphStorage",
|
56 |
+
vector_storage="PGVectorStorage",
|
57 |
)
|
58 |
# Set the KV/vector/graph storage's `db` property, so all operation will use same connection pool
|
59 |
rag.doc_status.db = postgres_db
|
|
|
77 |
start_time = time.time()
|
78 |
# Perform naive search
|
79 |
print(
|
80 |
+
await rag.aquery(
|
81 |
+
"What are the top themes in this story?", param=QueryParam(mode="naive")
|
82 |
+
)
|
83 |
)
|
84 |
print(f"Naive Query Time: {time.time() - start_time} seconds")
|
85 |
# Perform local search
|
86 |
print("**** Start Local Query ****")
|
87 |
start_time = time.time()
|
88 |
print(
|
89 |
+
await rag.aquery(
|
90 |
+
"What are the top themes in this story?", param=QueryParam(mode="local")
|
91 |
+
)
|
92 |
)
|
93 |
print(f"Local Query Time: {time.time() - start_time} seconds")
|
94 |
# Perform global search
|
95 |
print("**** Start Global Query ****")
|
96 |
start_time = time.time()
|
97 |
print(
|
98 |
+
await rag.aquery(
|
99 |
+
"What are the top themes in this story?", param=QueryParam(mode="global")
|
100 |
+
)
|
101 |
)
|
102 |
print(f"Global Query Time: {time.time() - start_time}")
|
103 |
# Perform hybrid search
|
104 |
print("**** Start Hybrid Query ****")
|
105 |
print(
|
106 |
+
await rag.aquery(
|
107 |
+
"What are the top themes in this story?", param=QueryParam(mode="hybrid")
|
108 |
+
)
|
109 |
)
|
110 |
print(f"Hybrid Query Time: {time.time() - start_time} seconds")
|
111 |
|
lightrag/kg/postgres_impl.py
CHANGED
@@ -19,7 +19,11 @@ from tenacity import (
|
|
19 |
from ..utils import logger
|
20 |
from ..base import (
|
21 |
BaseKVStorage,
|
22 |
-
BaseVectorStorage,
|
|
|
|
|
|
|
|
|
23 |
)
|
24 |
|
25 |
if sys.platform.startswith("win"):
|
@@ -36,14 +40,15 @@ class PostgreSQLDB:
|
|
36 |
self.user = config.get("user", "postgres")
|
37 |
self.password = config.get("password", None)
|
38 |
self.database = config.get("database", "postgres")
|
39 |
-
self.workspace = config.get("workspace",
|
40 |
self.max = 12
|
41 |
self.increment = 1
|
42 |
logger.info(f"Using the label {self.workspace} for PostgreSQL as identifier")
|
43 |
|
44 |
if self.user is None or self.password is None or self.database is None:
|
45 |
-
raise ValueError(
|
46 |
-
|
|
|
47 |
|
48 |
async def initdb(self):
|
49 |
try:
|
@@ -54,12 +59,16 @@ class PostgreSQLDB:
|
|
54 |
host=self.host,
|
55 |
port=self.port,
|
56 |
min_size=1,
|
57 |
-
max_size=self.max
|
58 |
)
|
59 |
|
60 |
-
logger.info(
|
|
|
|
|
61 |
except Exception as e:
|
62 |
-
logger.error(
|
|
|
|
|
63 |
logger.error(f"PostgreSQL database error: {e}")
|
64 |
raise
|
65 |
|
@@ -79,9 +88,13 @@ class PostgreSQLDB:
|
|
79 |
|
80 |
logger.info("Finished checking all tables in PostgreSQL database")
|
81 |
|
82 |
-
|
83 |
async def query(
|
84 |
-
|
|
|
|
|
|
|
|
|
|
|
85 |
) -> Union[dict, None, list[dict]]:
|
86 |
async with self.pool.acquire() as connection:
|
87 |
try:
|
@@ -111,7 +124,13 @@ class PostgreSQLDB:
|
|
111 |
print(params)
|
112 |
raise
|
113 |
|
114 |
-
async def execute(
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
try:
|
116 |
async with self.pool.acquire() as connection:
|
117 |
if for_age:
|
@@ -130,7 +149,7 @@ class PostgreSQLDB:
|
|
130 |
@staticmethod
|
131 |
async def _prerequisite(conn: asyncpg.Connection, graph_name: str):
|
132 |
try:
|
133 |
-
await conn.execute(
|
134 |
await conn.execute(f"""select create_graph('{graph_name}')""")
|
135 |
except asyncpg.exceptions.InvalidSchemaNameError:
|
136 |
pass
|
@@ -138,7 +157,7 @@ class PostgreSQLDB:
|
|
138 |
|
139 |
@dataclass
|
140 |
class PGKVStorage(BaseKVStorage):
|
141 |
-
db:PostgreSQLDB = None
|
142 |
|
143 |
def __post_init__(self):
|
144 |
self._data = {}
|
@@ -180,7 +199,7 @@ class PGKVStorage(BaseKVStorage):
|
|
180 |
dict_res[mode] = {}
|
181 |
for row in array_res:
|
182 |
dict_res[row["mode"]][row["id"]] = row
|
183 |
-
res = [{k:v} for k,v in dict_res.items()]
|
184 |
else:
|
185 |
res = await self.db.query(sql, params, multirows=True)
|
186 |
if res:
|
@@ -191,7 +210,8 @@ class PGKVStorage(BaseKVStorage):
|
|
191 |
async def filter_keys(self, keys: List[str]) -> Set[str]:
|
192 |
"""Filter out duplicated content"""
|
193 |
sql = SQL_TEMPLATES["filter_keys"].format(
|
194 |
-
table_name=NAMESPACE_TABLE_MAP[self.namespace],
|
|
|
195 |
)
|
196 |
params = {"workspace": self.db.workspace}
|
197 |
try:
|
@@ -207,7 +227,6 @@ class PGKVStorage(BaseKVStorage):
|
|
207 |
print(sql)
|
208 |
print(params)
|
209 |
|
210 |
-
|
211 |
################ INSERT METHODS ################
|
212 |
async def upsert(self, data: Dict[str, dict]):
|
213 |
left_data = {k: v for k, v in data.items() if k not in self._data}
|
@@ -246,7 +265,7 @@ class PGKVStorage(BaseKVStorage):
|
|
246 |
@dataclass
|
247 |
class PGVectorStorage(BaseVectorStorage):
|
248 |
cosine_better_than_threshold: float = 0.2
|
249 |
-
db:PostgreSQLDB = None
|
250 |
|
251 |
def __post_init__(self):
|
252 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
@@ -282,6 +301,7 @@ class PGVectorStorage(BaseVectorStorage):
|
|
282 |
"content_vector": json.dumps(item["__vector__"].tolist()),
|
283 |
}
|
284 |
return upsert_sql, data
|
|
|
285 |
def _upsert_relationships(self, item: dict):
|
286 |
upsert_sql = SQL_TEMPLATES["upsert_relationship"]
|
287 |
data = {
|
@@ -340,8 +360,6 @@ class PGVectorStorage(BaseVectorStorage):
|
|
340 |
|
341 |
await self.db.execute(upsert_sql, data)
|
342 |
|
343 |
-
|
344 |
-
|
345 |
async def index_done_callback(self):
|
346 |
logger.info("vector data had been saved into postgresql db!")
|
347 |
|
@@ -350,7 +368,7 @@ class PGVectorStorage(BaseVectorStorage):
|
|
350 |
"""从向量数据库中查询数据"""
|
351 |
embeddings = await self.embedding_func([query])
|
352 |
embedding = embeddings[0]
|
353 |
-
embedding_string =
|
354 |
|
355 |
sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
|
356 |
params = {
|
@@ -361,10 +379,12 @@ class PGVectorStorage(BaseVectorStorage):
|
|
361 |
results = await self.db.query(sql, params=params, multirows=True)
|
362 |
return results
|
363 |
|
|
|
364 |
@dataclass
|
365 |
class PGDocStatusStorage(DocStatusStorage):
|
366 |
"""PostgreSQL implementation of document status storage"""
|
367 |
-
|
|
|
368 |
|
369 |
def __post_init__(self):
|
370 |
pass
|
@@ -372,41 +392,47 @@ class PGDocStatusStorage(DocStatusStorage):
|
|
372 |
async def filter_keys(self, data: list[str]) -> set[str]:
|
373 |
"""Return keys that don't exist in storage"""
|
374 |
sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id IN ({",".join([f"'{_id}'" for _id in data])})"
|
375 |
-
result = await self.db.query(sql, {
|
376 |
# The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
|
377 |
if result is None:
|
378 |
return set(data)
|
379 |
else:
|
380 |
-
existed = set([element[
|
381 |
return set(data) - existed
|
382 |
|
383 |
async def get_status_counts(self) -> Dict[str, int]:
|
384 |
"""Get counts of documents in each status"""
|
385 |
-
sql =
|
386 |
FROM LIGHTRAG_DOC_STATUS
|
387 |
where workspace=$1 GROUP BY STATUS
|
388 |
-
|
389 |
-
result = await self.db.query(sql, {
|
390 |
# Result is like [{'status': 'PENDING', 'count': 1}, {'status': 'PROCESSING', 'count': 2}, ...]
|
391 |
counts = {}
|
392 |
for doc in result:
|
393 |
counts[doc["status"]] = doc["count"]
|
394 |
return counts
|
395 |
|
396 |
-
async def get_docs_by_status(
|
|
|
|
|
397 |
"""Get all documents by status"""
|
398 |
-
sql =
|
399 |
-
params = {
|
400 |
result = await self.db.query(sql, params, True)
|
401 |
# Result is like [{'id': 'id1', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, {'id': 'id2', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, ...]
|
402 |
# Converting to be a dict
|
403 |
-
return {
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
|
|
|
|
|
|
|
|
410 |
|
411 |
async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
|
412 |
"""Get all failed documents"""
|
@@ -436,14 +462,17 @@ class PGDocStatusStorage(DocStatusStorage):
|
|
436 |
updated_at = CURRENT_TIMESTAMP"""
|
437 |
for k, v in data.items():
|
438 |
# chunks_count is optional
|
439 |
-
await self.db.execute(
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
|
|
|
|
|
|
447 |
return data
|
448 |
|
449 |
|
@@ -467,7 +496,7 @@ class PGGraphQueryException(Exception):
|
|
467 |
|
468 |
@dataclass
|
469 |
class PGGraphStorage(BaseGraphStorage):
|
470 |
-
db:PostgreSQLDB = None
|
471 |
|
472 |
@staticmethod
|
473 |
def load_nx_graph(file_name):
|
@@ -484,7 +513,6 @@ class PGGraphStorage(BaseGraphStorage):
|
|
484 |
"node2vec": self._node2vec_embed,
|
485 |
}
|
486 |
|
487 |
-
|
488 |
async def index_done_callback(self):
|
489 |
print("KG successfully indexed.")
|
490 |
|
@@ -552,7 +580,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
552 |
|
553 |
@staticmethod
|
554 |
def _format_properties(
|
555 |
-
|
556 |
) -> str:
|
557 |
"""
|
558 |
Convert a dictionary of properties to a string representation that
|
@@ -669,7 +697,8 @@ class PGGraphStorage(BaseGraphStorage):
|
|
669 |
|
670 |
# get pgsql formatted field names
|
671 |
fields = [
|
672 |
-
PGGraphStorage._get_col_name(field, idx)
|
|
|
673 |
]
|
674 |
|
675 |
# build resulting pgsql relation
|
@@ -690,7 +719,9 @@ class PGGraphStorage(BaseGraphStorage):
|
|
690 |
projection=select_str,
|
691 |
)
|
692 |
|
693 |
-
async def _query(
|
|
|
|
|
694 |
"""
|
695 |
Query the graph by taking a cypher query, converting it to an
|
696 |
age compatible query, executing it and converting the result
|
@@ -708,14 +739,25 @@ class PGGraphStorage(BaseGraphStorage):
|
|
708 |
# execute the query, rolling back on an error
|
709 |
try:
|
710 |
if readonly:
|
711 |
-
data = await self.db.query(
|
|
|
|
|
|
|
|
|
|
|
712 |
else:
|
713 |
# for upserting edge, need to run the SQL twice, otherwise cannot update the properties. (First time it will try to create the edge, second time is MERGING)
|
714 |
# It is a bug of AGE as of 2025-01-03, hope it can be resolved in the future.
|
715 |
if upsert_edge:
|
716 |
-
data = await self.db.execute(
|
|
|
|
|
|
|
|
|
717 |
else:
|
718 |
-
data = await self.db.execute(
|
|
|
|
|
719 |
except Exception as e:
|
720 |
raise PGGraphQueryException(
|
721 |
{
|
@@ -819,7 +861,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
819 |
return degrees
|
820 |
|
821 |
async def get_edge(
|
822 |
-
|
823 |
) -> Union[dict, None]:
|
824 |
"""
|
825 |
Find all edges between nodes of two given labels
|
@@ -922,7 +964,7 @@ class PGGraphStorage(BaseGraphStorage):
|
|
922 |
retry=retry_if_exception_type((PGGraphQueryException,)),
|
923 |
)
|
924 |
async def upsert_edge(
|
925 |
-
|
926 |
):
|
927 |
"""
|
928 |
Upsert an edge and its properties between two nodes identified by their labels.
|
@@ -935,7 +977,9 @@ class PGGraphStorage(BaseGraphStorage):
|
|
935 |
source_node_label = source_node_id.strip('"')
|
936 |
target_node_label = target_node_id.strip('"')
|
937 |
edge_properties = edge_data
|
938 |
-
logger.info(
|
|
|
|
|
939 |
|
940 |
query = """MATCH (source:`{src_label}`)
|
941 |
WITH source
|
@@ -1056,7 +1100,6 @@ TABLES = {
|
|
1056 |
}
|
1057 |
|
1058 |
|
1059 |
-
|
1060 |
SQL_TEMPLATES = {
|
1061 |
# SQL for KVStorage
|
1062 |
"get_by_id_full_docs": """SELECT id, COALESCE(content, '') as content
|
@@ -1107,7 +1150,7 @@ SQL_TEMPLATES = {
|
|
1107 |
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector)
|
1108 |
VALUES ($1, $2, $3, $4, $5)
|
1109 |
ON CONFLICT (workspace,id) DO UPDATE
|
1110 |
-
SET entity_name=EXCLUDED.entity_name,
|
1111 |
content=EXCLUDED.content,
|
1112 |
content_vector=EXCLUDED.content_vector,
|
1113 |
updatetime=CURRENT_TIMESTAMP
|
@@ -1136,5 +1179,5 @@ SQL_TEMPLATES = {
|
|
1136 |
(SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
1137 |
FROM LIGHTRAG_DOC_CHUNKS where workspace=$1)
|
1138 |
WHERE distance>$2 ORDER BY distance DESC LIMIT $3
|
1139 |
-
"""
|
1140 |
}
|
|
|
19 |
from ..utils import logger
|
20 |
from ..base import (
|
21 |
BaseKVStorage,
|
22 |
+
BaseVectorStorage,
|
23 |
+
DocStatusStorage,
|
24 |
+
DocStatus,
|
25 |
+
DocProcessingStatus,
|
26 |
+
BaseGraphStorage,
|
27 |
)
|
28 |
|
29 |
if sys.platform.startswith("win"):
|
|
|
40 |
self.user = config.get("user", "postgres")
|
41 |
self.password = config.get("password", None)
|
42 |
self.database = config.get("database", "postgres")
|
43 |
+
self.workspace = config.get("workspace", "default")
|
44 |
self.max = 12
|
45 |
self.increment = 1
|
46 |
logger.info(f"Using the label {self.workspace} for PostgreSQL as identifier")
|
47 |
|
48 |
if self.user is None or self.password is None or self.database is None:
|
49 |
+
raise ValueError(
|
50 |
+
"Missing database user, password, or database in addon_params"
|
51 |
+
)
|
52 |
|
53 |
async def initdb(self):
|
54 |
try:
|
|
|
59 |
host=self.host,
|
60 |
port=self.port,
|
61 |
min_size=1,
|
62 |
+
max_size=self.max,
|
63 |
)
|
64 |
|
65 |
+
logger.info(
|
66 |
+
f"Connected to PostgreSQL database at {self.host}:{self.port}/{self.database}"
|
67 |
+
)
|
68 |
except Exception as e:
|
69 |
+
logger.error(
|
70 |
+
f"Failed to connect to PostgreSQL database at {self.host}:{self.port}/{self.database}"
|
71 |
+
)
|
72 |
logger.error(f"PostgreSQL database error: {e}")
|
73 |
raise
|
74 |
|
|
|
88 |
|
89 |
logger.info("Finished checking all tables in PostgreSQL database")
|
90 |
|
|
|
91 |
async def query(
|
92 |
+
self,
|
93 |
+
sql: str,
|
94 |
+
params: dict = None,
|
95 |
+
multirows: bool = False,
|
96 |
+
for_age: bool = False,
|
97 |
+
graph_name: str = None,
|
98 |
) -> Union[dict, None, list[dict]]:
|
99 |
async with self.pool.acquire() as connection:
|
100 |
try:
|
|
|
124 |
print(params)
|
125 |
raise
|
126 |
|
127 |
+
async def execute(
|
128 |
+
self,
|
129 |
+
sql: str,
|
130 |
+
data: Union[list, dict] = None,
|
131 |
+
for_age: bool = False,
|
132 |
+
graph_name: str = None,
|
133 |
+
):
|
134 |
try:
|
135 |
async with self.pool.acquire() as connection:
|
136 |
if for_age:
|
|
|
149 |
@staticmethod
|
150 |
async def _prerequisite(conn: asyncpg.Connection, graph_name: str):
|
151 |
try:
|
152 |
+
await conn.execute('SET search_path = ag_catalog, "$user", public')
|
153 |
await conn.execute(f"""select create_graph('{graph_name}')""")
|
154 |
except asyncpg.exceptions.InvalidSchemaNameError:
|
155 |
pass
|
|
|
157 |
|
158 |
@dataclass
|
159 |
class PGKVStorage(BaseKVStorage):
|
160 |
+
db: PostgreSQLDB = None
|
161 |
|
162 |
def __post_init__(self):
|
163 |
self._data = {}
|
|
|
199 |
dict_res[mode] = {}
|
200 |
for row in array_res:
|
201 |
dict_res[row["mode"]][row["id"]] = row
|
202 |
+
res = [{k: v} for k, v in dict_res.items()]
|
203 |
else:
|
204 |
res = await self.db.query(sql, params, multirows=True)
|
205 |
if res:
|
|
|
210 |
async def filter_keys(self, keys: List[str]) -> Set[str]:
|
211 |
"""Filter out duplicated content"""
|
212 |
sql = SQL_TEMPLATES["filter_keys"].format(
|
213 |
+
table_name=NAMESPACE_TABLE_MAP[self.namespace],
|
214 |
+
ids=",".join([f"'{id}'" for id in keys]),
|
215 |
)
|
216 |
params = {"workspace": self.db.workspace}
|
217 |
try:
|
|
|
227 |
print(sql)
|
228 |
print(params)
|
229 |
|
|
|
230 |
################ INSERT METHODS ################
|
231 |
async def upsert(self, data: Dict[str, dict]):
|
232 |
left_data = {k: v for k, v in data.items() if k not in self._data}
|
|
|
265 |
@dataclass
|
266 |
class PGVectorStorage(BaseVectorStorage):
|
267 |
cosine_better_than_threshold: float = 0.2
|
268 |
+
db: PostgreSQLDB = None
|
269 |
|
270 |
def __post_init__(self):
|
271 |
self._max_batch_size = self.global_config["embedding_batch_num"]
|
|
|
301 |
"content_vector": json.dumps(item["__vector__"].tolist()),
|
302 |
}
|
303 |
return upsert_sql, data
|
304 |
+
|
305 |
def _upsert_relationships(self, item: dict):
|
306 |
upsert_sql = SQL_TEMPLATES["upsert_relationship"]
|
307 |
data = {
|
|
|
360 |
|
361 |
await self.db.execute(upsert_sql, data)
|
362 |
|
|
|
|
|
363 |
async def index_done_callback(self):
|
364 |
logger.info("vector data had been saved into postgresql db!")
|
365 |
|
|
|
368 |
"""从向量数据库中查询数据"""
|
369 |
embeddings = await self.embedding_func([query])
|
370 |
embedding = embeddings[0]
|
371 |
+
embedding_string = ",".join(map(str, embedding))
|
372 |
|
373 |
sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
|
374 |
params = {
|
|
|
379 |
results = await self.db.query(sql, params=params, multirows=True)
|
380 |
return results
|
381 |
|
382 |
+
|
383 |
@dataclass
|
384 |
class PGDocStatusStorage(DocStatusStorage):
|
385 |
"""PostgreSQL implementation of document status storage"""
|
386 |
+
|
387 |
+
db: PostgreSQLDB = None
|
388 |
|
389 |
def __post_init__(self):
|
390 |
pass
|
|
|
392 |
async def filter_keys(self, data: list[str]) -> set[str]:
|
393 |
"""Return keys that don't exist in storage"""
|
394 |
sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id IN ({",".join([f"'{_id}'" for _id in data])})"
|
395 |
+
result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
|
396 |
# The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
|
397 |
if result is None:
|
398 |
return set(data)
|
399 |
else:
|
400 |
+
existed = set([element["id"] for element in result])
|
401 |
return set(data) - existed
|
402 |
|
403 |
async def get_status_counts(self) -> Dict[str, int]:
|
404 |
"""Get counts of documents in each status"""
|
405 |
+
sql = """SELECT status as "status", COUNT(1) as "count"
|
406 |
FROM LIGHTRAG_DOC_STATUS
|
407 |
where workspace=$1 GROUP BY STATUS
|
408 |
+
"""
|
409 |
+
result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
|
410 |
# Result is like [{'status': 'PENDING', 'count': 1}, {'status': 'PROCESSING', 'count': 2}, ...]
|
411 |
counts = {}
|
412 |
for doc in result:
|
413 |
counts[doc["status"]] = doc["count"]
|
414 |
return counts
|
415 |
|
416 |
+
async def get_docs_by_status(
|
417 |
+
self, status: DocStatus
|
418 |
+
) -> Dict[str, DocProcessingStatus]:
|
419 |
"""Get all documents by status"""
|
420 |
+
sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$1"
|
421 |
+
params = {"workspace": self.db.workspace, "status": status}
|
422 |
result = await self.db.query(sql, params, True)
|
423 |
# Result is like [{'id': 'id1', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, {'id': 'id2', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, ...]
|
424 |
# Converting to be a dict
|
425 |
+
return {
|
426 |
+
element["id"]: DocProcessingStatus(
|
427 |
+
content_summary=element["content_summary"],
|
428 |
+
content_length=element["content_length"],
|
429 |
+
status=element["status"],
|
430 |
+
created_at=element["created_at"],
|
431 |
+
updated_at=element["updated_at"],
|
432 |
+
chunks_count=element["chunks_count"],
|
433 |
+
)
|
434 |
+
for element in result
|
435 |
+
}
|
436 |
|
437 |
async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
|
438 |
"""Get all failed documents"""
|
|
|
462 |
updated_at = CURRENT_TIMESTAMP"""
|
463 |
for k, v in data.items():
|
464 |
# chunks_count is optional
|
465 |
+
await self.db.execute(
|
466 |
+
sql,
|
467 |
+
{
|
468 |
+
"workspace": self.db.workspace,
|
469 |
+
"id": k,
|
470 |
+
"content_summary": v["content_summary"],
|
471 |
+
"content_length": v["content_length"],
|
472 |
+
"chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
|
473 |
+
"status": v["status"],
|
474 |
+
},
|
475 |
+
)
|
476 |
return data
|
477 |
|
478 |
|
|
|
496 |
|
497 |
@dataclass
|
498 |
class PGGraphStorage(BaseGraphStorage):
|
499 |
+
db: PostgreSQLDB = None
|
500 |
|
501 |
@staticmethod
|
502 |
def load_nx_graph(file_name):
|
|
|
513 |
"node2vec": self._node2vec_embed,
|
514 |
}
|
515 |
|
|
|
516 |
async def index_done_callback(self):
|
517 |
print("KG successfully indexed.")
|
518 |
|
|
|
580 |
|
581 |
@staticmethod
|
582 |
def _format_properties(
|
583 |
+
properties: Dict[str, Any], _id: Union[str, None] = None
|
584 |
) -> str:
|
585 |
"""
|
586 |
Convert a dictionary of properties to a string representation that
|
|
|
697 |
|
698 |
# get pgsql formatted field names
|
699 |
fields = [
|
700 |
+
PGGraphStorage._get_col_name(field, idx)
|
701 |
+
for idx, field in enumerate(fields)
|
702 |
]
|
703 |
|
704 |
# build resulting pgsql relation
|
|
|
719 |
projection=select_str,
|
720 |
)
|
721 |
|
722 |
+
async def _query(
|
723 |
+
self, query: str, readonly=True, upsert_edge=False, **params: str
|
724 |
+
) -> List[Dict[str, Any]]:
|
725 |
"""
|
726 |
Query the graph by taking a cypher query, converting it to an
|
727 |
age compatible query, executing it and converting the result
|
|
|
739 |
# execute the query, rolling back on an error
|
740 |
try:
|
741 |
if readonly:
|
742 |
+
data = await self.db.query(
|
743 |
+
wrapped_query,
|
744 |
+
multirows=True,
|
745 |
+
for_age=True,
|
746 |
+
graph_name=self.graph_name,
|
747 |
+
)
|
748 |
else:
|
749 |
# for upserting edge, need to run the SQL twice, otherwise cannot update the properties. (First time it will try to create the edge, second time is MERGING)
|
750 |
# It is a bug of AGE as of 2025-01-03, hope it can be resolved in the future.
|
751 |
if upsert_edge:
|
752 |
+
data = await self.db.execute(
|
753 |
+
f"{wrapped_query};{wrapped_query};",
|
754 |
+
for_age=True,
|
755 |
+
graph_name=self.graph_name,
|
756 |
+
)
|
757 |
else:
|
758 |
+
data = await self.db.execute(
|
759 |
+
wrapped_query, for_age=True, graph_name=self.graph_name
|
760 |
+
)
|
761 |
except Exception as e:
|
762 |
raise PGGraphQueryException(
|
763 |
{
|
|
|
861 |
return degrees
|
862 |
|
863 |
async def get_edge(
|
864 |
+
self, source_node_id: str, target_node_id: str
|
865 |
) -> Union[dict, None]:
|
866 |
"""
|
867 |
Find all edges between nodes of two given labels
|
|
|
964 |
retry=retry_if_exception_type((PGGraphQueryException,)),
|
965 |
)
|
966 |
async def upsert_edge(
|
967 |
+
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
|
968 |
):
|
969 |
"""
|
970 |
Upsert an edge and its properties between two nodes identified by their labels.
|
|
|
977 |
source_node_label = source_node_id.strip('"')
|
978 |
target_node_label = target_node_id.strip('"')
|
979 |
edge_properties = edge_data
|
980 |
+
logger.info(
|
981 |
+
f"-- inserting edge: {source_node_label} -> {target_node_label}: {edge_data}"
|
982 |
+
)
|
983 |
|
984 |
query = """MATCH (source:`{src_label}`)
|
985 |
WITH source
|
|
|
1100 |
}
|
1101 |
|
1102 |
|
|
|
1103 |
SQL_TEMPLATES = {
|
1104 |
# SQL for KVStorage
|
1105 |
"get_by_id_full_docs": """SELECT id, COALESCE(content, '') as content
|
|
|
1150 |
"upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector)
|
1151 |
VALUES ($1, $2, $3, $4, $5)
|
1152 |
ON CONFLICT (workspace,id) DO UPDATE
|
1153 |
+
SET entity_name=EXCLUDED.entity_name,
|
1154 |
content=EXCLUDED.content,
|
1155 |
content_vector=EXCLUDED.content_vector,
|
1156 |
updatetime=CURRENT_TIMESTAMP
|
|
|
1179 |
(SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
|
1180 |
FROM LIGHTRAG_DOC_CHUNKS where workspace=$1)
|
1181 |
WHERE distance>$2 ORDER BY distance DESC LIMIT $3
|
1182 |
+
""",
|
1183 |
}
|
lightrag/kg/postgres_impl_test.py
CHANGED
@@ -1,33 +1,39 @@
|
|
1 |
import asyncio
|
2 |
import asyncpg
|
3 |
-
import sys
|
|
|
4 |
|
5 |
import psycopg
|
6 |
from psycopg_pool import AsyncConnectionPool
|
7 |
from lightrag.kg.postgres_impl import PostgreSQLDB, PGGraphStorage
|
8 |
|
9 |
-
DB="rag"
|
10 |
-
USER="rag"
|
11 |
-
PASSWORD="rag"
|
12 |
-
HOST="localhost"
|
13 |
-
PORT="15432"
|
14 |
os.environ["AGE_GRAPH_NAME"] = "dickens"
|
15 |
|
16 |
if sys.platform.startswith("win"):
|
17 |
import asyncio.windows_events
|
|
|
18 |
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
19 |
|
|
|
20 |
async def get_pool():
|
21 |
return await asyncpg.create_pool(
|
22 |
f"postgres://{USER}:{PASSWORD}@{HOST}:{PORT}/{DB}",
|
23 |
min_size=10,
|
24 |
max_size=10,
|
25 |
max_queries=5000,
|
26 |
-
max_inactive_connection_lifetime=300.0
|
27 |
)
|
28 |
|
|
|
29 |
async def main1():
|
30 |
-
connection_string =
|
|
|
|
|
31 |
pool = AsyncConnectionPool(connection_string, open=False)
|
32 |
await pool.open()
|
33 |
|
@@ -36,18 +42,19 @@ async def main1():
|
|
36 |
async with conn.cursor() as curs:
|
37 |
try:
|
38 |
await curs.execute('SET search_path = ag_catalog, "$user", public')
|
39 |
-
await curs.execute(
|
40 |
await conn.commit()
|
41 |
print("create_graph success")
|
42 |
except (
|
43 |
-
|
44 |
-
|
45 |
):
|
46 |
print("create_graph already exists")
|
47 |
await conn.rollback()
|
48 |
finally:
|
49 |
pass
|
50 |
|
|
|
51 |
db = PostgreSQLDB(
|
52 |
config={
|
53 |
"host": "localhost",
|
@@ -58,6 +65,7 @@ db = PostgreSQLDB(
|
|
58 |
}
|
59 |
)
|
60 |
|
|
|
61 |
async def query_with_age():
|
62 |
await db.initdb()
|
63 |
graph = PGGraphStorage(
|
@@ -69,6 +77,7 @@ async def query_with_age():
|
|
69 |
res = await graph.get_node('"CHRISTMAS-TIME"')
|
70 |
print("Node is: ", res)
|
71 |
|
|
|
72 |
async def create_edge_with_age():
|
73 |
await db.initdb()
|
74 |
graph = PGGraphStorage(
|
@@ -89,31 +98,28 @@ async def create_edge_with_age():
|
|
89 |
"source_id": "chunk-1d4b58de5429cd1261370c231c8673e8",
|
90 |
},
|
91 |
)
|
92 |
-
res = await graph.get_edge(
|
93 |
print("Edge is: ", res)
|
94 |
|
95 |
|
96 |
async def main():
|
97 |
pool = await get_pool()
|
98 |
-
# 如果还有其它什么特殊参数,也可以直接往里面传递,因为设置了 **connect_kwargs
|
99 |
-
# 专门用来设置一些数据库独有的某些属性
|
100 |
-
# 从池子中取出一个连接
|
101 |
sql = r"SELECT * FROM ag_catalog.cypher('dickens', $$ MATCH (n:帅哥) RETURN n $$) AS (n ag_catalog.agtype)"
|
102 |
# cypher = "MATCH (n:how_are_you_doing) RETURN n"
|
103 |
async with pool.acquire() as conn:
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
|
|
|
|
111 |
|
112 |
-
|
113 |
-
|
114 |
-
print(row) # <Record result=300>
|
115 |
-
# 我们的连接是从池子里面取出的,上下文结束之后会自动放回到到池子里面
|
116 |
|
117 |
|
118 |
-
if __name__ ==
|
119 |
asyncio.run(query_with_age())
|
|
|
1 |
import asyncio
|
2 |
import asyncpg
|
3 |
+
import sys
|
4 |
+
import os
|
5 |
|
6 |
import psycopg
|
7 |
from psycopg_pool import AsyncConnectionPool
|
8 |
from lightrag.kg.postgres_impl import PostgreSQLDB, PGGraphStorage
|
9 |
|
10 |
+
DB = "rag"
|
11 |
+
USER = "rag"
|
12 |
+
PASSWORD = "rag"
|
13 |
+
HOST = "localhost"
|
14 |
+
PORT = "15432"
|
15 |
os.environ["AGE_GRAPH_NAME"] = "dickens"
|
16 |
|
17 |
if sys.platform.startswith("win"):
|
18 |
import asyncio.windows_events
|
19 |
+
|
20 |
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
21 |
|
22 |
+
|
23 |
async def get_pool():
|
24 |
return await asyncpg.create_pool(
|
25 |
f"postgres://{USER}:{PASSWORD}@{HOST}:{PORT}/{DB}",
|
26 |
min_size=10,
|
27 |
max_size=10,
|
28 |
max_queries=5000,
|
29 |
+
max_inactive_connection_lifetime=300.0,
|
30 |
)
|
31 |
|
32 |
+
|
33 |
async def main1():
|
34 |
+
connection_string = (
|
35 |
+
f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}"
|
36 |
+
)
|
37 |
pool = AsyncConnectionPool(connection_string, open=False)
|
38 |
await pool.open()
|
39 |
|
|
|
42 |
async with conn.cursor() as curs:
|
43 |
try:
|
44 |
await curs.execute('SET search_path = ag_catalog, "$user", public')
|
45 |
+
await curs.execute("SELECT create_graph('dickens-2')")
|
46 |
await conn.commit()
|
47 |
print("create_graph success")
|
48 |
except (
|
49 |
+
psycopg.errors.InvalidSchemaName,
|
50 |
+
psycopg.errors.UniqueViolation,
|
51 |
):
|
52 |
print("create_graph already exists")
|
53 |
await conn.rollback()
|
54 |
finally:
|
55 |
pass
|
56 |
|
57 |
+
|
58 |
db = PostgreSQLDB(
|
59 |
config={
|
60 |
"host": "localhost",
|
|
|
65 |
}
|
66 |
)
|
67 |
|
68 |
+
|
69 |
async def query_with_age():
|
70 |
await db.initdb()
|
71 |
graph = PGGraphStorage(
|
|
|
77 |
res = await graph.get_node('"CHRISTMAS-TIME"')
|
78 |
print("Node is: ", res)
|
79 |
|
80 |
+
|
81 |
async def create_edge_with_age():
|
82 |
await db.initdb()
|
83 |
graph = PGGraphStorage(
|
|
|
98 |
"source_id": "chunk-1d4b58de5429cd1261370c231c8673e8",
|
99 |
},
|
100 |
)
|
101 |
+
res = await graph.get_edge("THE CRATCHITS", '"THE GIRLS"')
|
102 |
print("Edge is: ", res)
|
103 |
|
104 |
|
105 |
async def main():
|
106 |
pool = await get_pool()
|
|
|
|
|
|
|
107 |
sql = r"SELECT * FROM ag_catalog.cypher('dickens', $$ MATCH (n:帅哥) RETURN n $$) AS (n ag_catalog.agtype)"
|
108 |
# cypher = "MATCH (n:how_are_you_doing) RETURN n"
|
109 |
async with pool.acquire() as conn:
|
110 |
+
try:
|
111 |
+
await conn.execute(
|
112 |
+
"""SET search_path = ag_catalog, "$user", public;select create_graph('dickens')"""
|
113 |
+
)
|
114 |
+
except asyncpg.exceptions.InvalidSchemaNameError:
|
115 |
+
print("create_graph already exists")
|
116 |
+
# stmt = await conn.prepare(sql)
|
117 |
+
row = await conn.fetch(sql)
|
118 |
+
print("row is: ", row)
|
119 |
|
120 |
+
row = await conn.fetchrow("select '100'::int + 200 as result")
|
121 |
+
print(row) # <Record result=300>
|
|
|
|
|
122 |
|
123 |
|
124 |
+
if __name__ == "__main__":
|
125 |
asyncio.run(query_with_age())
|
requirements.txt
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
accelerate
|
2 |
aioboto3~=13.3.0
|
|
|
3 |
aiohttp~=3.11.11
|
|
|
4 |
|
5 |
# database packages
|
6 |
graspologic
|
@@ -9,14 +11,20 @@ hnswlib
|
|
9 |
nano-vectordb
|
10 |
neo4j~=5.27.0
|
11 |
networkx~=3.2.1
|
|
|
|
|
12 |
ollama~=0.4.4
|
13 |
openai~=1.58.1
|
14 |
oracledb
|
|
|
15 |
psycopg[binary,pool]~=3.2.3
|
|
|
16 |
pymilvus
|
17 |
pymongo
|
18 |
pymysql
|
|
|
19 |
pyvis~=0.3.2
|
|
|
20 |
# lmdeploy[all]
|
21 |
sqlalchemy~=2.0.36
|
22 |
tenacity~=9.0.0
|
@@ -25,14 +33,6 @@ tenacity~=9.0.0
|
|
25 |
# LLM packages
|
26 |
tiktoken~=0.8.0
|
27 |
torch~=2.5.1+cu121
|
|
|
28 |
transformers~=4.47.1
|
29 |
xxhash
|
30 |
-
|
31 |
-
numpy~=2.2.0
|
32 |
-
aiofiles~=24.1.0
|
33 |
-
pydantic~=2.10.4
|
34 |
-
python-dotenv~=1.0.1
|
35 |
-
psycopg-pool~=3.2.4
|
36 |
-
tqdm~=4.67.1
|
37 |
-
asyncpg~=0.30.0
|
38 |
-
setuptools~=70.0.0
|
|
|
1 |
accelerate
|
2 |
aioboto3~=13.3.0
|
3 |
+
aiofiles~=24.1.0
|
4 |
aiohttp~=3.11.11
|
5 |
+
asyncpg~=0.30.0
|
6 |
|
7 |
# database packages
|
8 |
graspologic
|
|
|
11 |
nano-vectordb
|
12 |
neo4j~=5.27.0
|
13 |
networkx~=3.2.1
|
14 |
+
|
15 |
+
numpy~=2.2.0
|
16 |
ollama~=0.4.4
|
17 |
openai~=1.58.1
|
18 |
oracledb
|
19 |
+
psycopg-pool~=3.2.4
|
20 |
psycopg[binary,pool]~=3.2.3
|
21 |
+
pydantic~=2.10.4
|
22 |
pymilvus
|
23 |
pymongo
|
24 |
pymysql
|
25 |
+
python-dotenv~=1.0.1
|
26 |
pyvis~=0.3.2
|
27 |
+
setuptools~=70.0.0
|
28 |
# lmdeploy[all]
|
29 |
sqlalchemy~=2.0.36
|
30 |
tenacity~=9.0.0
|
|
|
33 |
# LLM packages
|
34 |
tiktoken~=0.8.0
|
35 |
torch~=2.5.1+cu121
|
36 |
+
tqdm~=4.67.1
|
37 |
transformers~=4.47.1
|
38 |
xxhash
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|