samuel-z-chen commited on
Commit
359e407
·
1 Parent(s): 3c5ab1e

With a draft for postgres_impl

Browse files
examples/lightrag_zhipu_postgres_demo.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import asyncio
import inspect
import logging
import os
import time
from dotenv import load_dotenv

from lightrag import LightRAG, QueryParam
from lightrag.kg.postgres_impl import PostgreSQLDB, PGGraphStorage
from lightrag.llm import ollama_embedding, zhipu_complete
from lightrag.utils import EmbeddingFunc

# Pull environment variables (ROOT_DIR, credentials, ...) from a local .env file.
load_dotenv()
ROOT_DIR = os.environ.get("ROOT_DIR")
# Working directory for LightRAG artifacts.
# NOTE(review): if ROOT_DIR is unset this becomes the literal "None/dickens-pg"
# — confirm ROOT_DIR is always exported before running this demo.
WORKING_DIR = f"{ROOT_DIR}/dickens-pg"

logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)

# AGE: graph name read by PGGraphStorage from the environment.
os.environ["AGE_GRAPH_NAME"] = "dickens"

# Shared PostgreSQL connection-pool wrapper; the same instance is injected
# into every storage object in main() so they all reuse one pool.
postgres_db = PostgreSQLDB(
    config={
        "host": "localhost",
        "port": 15432,
        "user": "rag",
        "password": "rag",
        "database": "rag",
    }
)
34
+
35
+
36
async def main():
    """End-to-end demo: index a book into PostgreSQL-backed LightRAG storage,
    then run naive/local/global/hybrid queries plus a streaming query."""
    await postgres_db.initdb()
    # Check if PostgreSQL DB tables exist, if not, tables will be created
    await postgres_db.check_tables()

    rag = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=zhipu_complete,
        llm_model_name="glm-4-flashx",
        llm_model_max_async=4,
        llm_model_max_token_size=32768,
        embedding_func=EmbeddingFunc(
            embedding_dim=768,
            max_token_size=8192,
            func=lambda texts: ollama_embedding(
                texts, embed_model="nomic-embed-text", host="http://localhost:11434"
            ),
        ),
        kv_storage="PGKVStorage",
        doc_status_storage="PGDocStatusStorage",
        graph_storage="PGGraphStorage",
        vector_storage="PGVectorStorage",
    )
    # Set the KV/vector/graph storage's `db` property, so all operations will
    # use the same connection pool.
    rag.doc_status.db = postgres_db
    rag.full_docs.db = postgres_db
    rag.text_chunks.db = postgres_db
    rag.llm_response_cache.db = postgres_db
    rag.key_string_value_json_storage_cls.db = postgres_db
    rag.chunks_vdb.db = postgres_db
    rag.relationships_vdb.db = postgres_db
    rag.entities_vdb.db = postgres_db
    rag.graph_storage_cls.db = postgres_db
    rag.chunk_entity_relation_graph.db = postgres_db
    await rag.chunk_entity_relation_graph.check_graph_exists()
    # add embedding_func for graph database, it's deleted in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c
    rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func

    with open(f"{ROOT_DIR}/book.txt", "r", encoding="utf-8") as f:
        await rag.ainsert(f.read())

    print("==== Trying to test the rag queries ====")
    print("**** Start Naive Query ****")
    start_time = time.time()
    # Perform naive search
    print(
        await rag.aquery("What are the top themes in this story?", param=QueryParam(mode="naive"))
    )
    print(f"Naive Query Time: {time.time() - start_time} seconds")
    # Perform local search
    print("**** Start Local Query ****")
    start_time = time.time()
    print(
        await rag.aquery("What are the top themes in this story?", param=QueryParam(mode="local"))
    )
    print(f"Local Query Time: {time.time() - start_time} seconds")
    # Perform global search
    print("**** Start Global Query ****")
    start_time = time.time()
    print(
        await rag.aquery("What are the top themes in this story?", param=QueryParam(mode="global"))
    )
    print(f"Global Query Time: {time.time() - start_time}")
    # Perform hybrid search
    print("**** Start Hybrid Query ****")
    print(
        await rag.aquery("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
    )
    print(f"Hybrid Query Time: {time.time() - start_time} seconds")

    print("**** Start Stream Query ****")
    start_time = time.time()
    # stream response
    resp = await rag.aquery(
        "What are the top themes in this story?",
        param=QueryParam(mode="hybrid", stream=True),
    )
    print(f"Stream Query Time: {time.time() - start_time} seconds")
    print("**** Done Stream Query ****")

    if inspect.isasyncgen(resp):
        # BUGFIX: the original called asyncio.run(print_stream(resp)) here,
        # which raises RuntimeError because main() already runs inside
        # asyncio.run(), and referenced print_stream before the module had
        # defined it (it is declared below the __main__ guard). Consume the
        # async generator directly instead.
        async for chunk in resp:
            print(chunk, end="", flush=True)
    else:
        print(resp)
120
+
121
+
122
if __name__ == "__main__":
    # NOTE(review): print_stream is defined *below* this guard, so it is not
    # yet bound when main() executes — confirm main() does not reference it,
    # or move that definition above this block.
    asyncio.run(main())
124
+
125
+
126
async def print_stream(stream):
    """Print chunks from an async generator as they arrive, without newlines.

    NOTE(review): this is defined after the ``__main__`` guard that runs
    main(); any call from main() would hit an unbound name. Consider moving
    this definition above the guard.
    """
    async for chunk in stream:
        print(chunk, end="", flush=True)
129
+
130
+
131
+
132
+
133
+
lightrag/kg/postgres_impl.py ADDED
@@ -0,0 +1,1162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import asyncio
import inspect
import json
import os
import time
from dataclasses import dataclass
from typing import Union, List, Dict, Set, Any, Tuple
import numpy as np
import asyncpg
import sys
from tqdm.asyncio import tqdm as tqdm_async
from tenacity import (
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from ..utils import logger
from ..base import (
    BaseKVStorage,
    BaseVectorStorage, DocStatusStorage, DocStatus, DocProcessingStatus, BaseGraphStorage,
)

if sys.platform.startswith("win"):
    import asyncio.windows_events

    # Force the selector-based event loop on Windows — presumably because the
    # default Proactor loop does not work with a dependency used here (likely
    # asyncpg); TODO confirm.
    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
29
+
30
+
31
class PostgreSQLDB:
    """Thin async wrapper around an asyncpg connection pool.

    One instance is shared by all PG-backed storages so they reuse a single
    pool; ``workspace`` labels the rows those storages read and write.
    """

    def __init__(self, config, **kwargs):
        self.pool = None  # created lazily in initdb()
        self.host = config.get("host", "localhost")
        self.port = config.get("port", 5432)
        self.user = config.get("user", "postgres")
        self.password = config.get("password", None)
        self.database = config.get("database", "postgres")
        self.workspace = config.get("workspace", 'default')
        self.max = 12       # connection-pool max_size (see initdb)
        self.increment = 1  # NOTE(review): unused in this chunk — confirm needed
        logger.info(f"Using the label {self.workspace} for PostgreSQL as identifier")

        # NOTE(review): user and database have non-None defaults above, so in
        # practice only a missing password can trigger this check.
        if self.user is None or self.password is None or self.database is None:
            raise ValueError("Missing database user, password, or database in addon_params")

    async def initdb(self):
        """Create the asyncpg pool; must be awaited before query()/execute()."""
        try:
            self.pool = await asyncpg.create_pool(
                user=self.user,
                password=self.password,
                database=self.database,
                host=self.host,
                port=self.port,
                min_size=1,
                max_size=self.max
            )

            logger.info(f"Connected to PostgreSQL database at {self.host}:{self.port}/{self.database}")
        except Exception as e:
            logger.error(f"Failed to connect to PostgreSQL database at {self.host}:{self.port}/{self.database}")
            logger.error(f"PostgreSQL database error: {e}")
            raise

    async def check_tables(self):
        """Probe each table with a SELECT; on failure, create it from its DDL.

        NOTE(review): TABLES is referenced but not defined in this chunk —
        confirm it is declared elsewhere in this module.
        """
        for k, v in TABLES.items():
            try:
                await self.query("SELECT 1 FROM {k} LIMIT 1".format(k=k))
            except Exception as e:
                logger.error(f"Failed to check table {k} in PostgreSQL database")
                logger.error(f"PostgreSQL database error: {e}")
                try:
                    await self.execute(v["ddl"])
                    logger.info(f"Created table {k} in PostgreSQL database")
                except Exception as e:
                    logger.error(f"Failed to create table {k} in PostgreSQL database")
                    logger.error(f"PostgreSQL database error: {e}")

        logger.info("Finished checking all tables in PostgreSQL database")

    async def query(
        self, sql: str, params: dict = None, multirows: bool = False, for_age: bool = False
    ) -> Union[dict, None, list[dict]]:
        """Run a fetch and return row dicts.

        Args:
            sql: statement with $1..$n placeholders.
            params: values are bound positionally in dict insertion order
                ($1 = first value inserted) — key names are not matched.
            multirows: when True return a list of row dicts (possibly empty);
                otherwise the first row as a dict, or None when no rows.
            for_age: when True, put ag_catalog on the search_path first so
                Apache AGE functions resolve.
        """
        async with self.pool.acquire() as connection:
            try:
                if for_age:
                    await connection.execute('SET search_path = ag_catalog, "$user", public')
                if params:
                    rows = await connection.fetch(sql, *params.values())
                else:
                    rows = await connection.fetch(sql)

                if multirows:
                    if rows:
                        columns = [col for col in rows[0].keys()]
                        data = [dict(zip(columns, row)) for row in rows]
                    else:
                        data = []
                else:
                    if rows:
                        columns = rows[0].keys()
                        data = dict(zip(columns, rows[0]))
                    else:
                        data = None
                return data
            except Exception as e:
                # Log plus dump the failing statement for debugging, then re-raise.
                logger.error(f"PostgreSQL database error: {e}")
                print(sql)
                print(params)
                raise

    async def execute(self, sql: str, data: Union[list, dict] = None, for_age: bool = False):
        """Run a non-returning statement; dict values are bound positionally.

        NOTE(review): the annotation allows a list, but a list would fail on
        .values() below — confirm only dicts are ever passed.
        """
        try:
            async with self.pool.acquire() as connection:
                if for_age:
                    await connection.execute('SET search_path = ag_catalog, "$user", public')

                if data is None:
                    await connection.execute(sql)
                else:
                    await connection.execute(sql, *data.values())
        except Exception as e:
            logger.error(f"PostgreSQL database error: {e}")
            print(sql)
            print(data)
            raise
132
+
133
+
134
@dataclass
class PGKVStorage(BaseKVStorage):
    """Key-value storage backed by PostgreSQL tables.

    Behavior depends on ``self.namespace`` (full_docs, text_chunks,
    llm_response_cache, ...); each namespace maps to its own table and SQL
    templates (SQL_TEMPLATES / NAMESPACE_TABLE_MAP, defined elsewhere in this
    module).
    """

    # Shared connection-pool wrapper; injected by the caller after construction.
    db: PostgreSQLDB = None

    def __post_init__(self):
        # In-memory staging area mirroring what upsert() has written.
        self._data = {}
        self._max_batch_size = self.global_config["embedding_batch_num"]

    ################ QUERY METHODS ################

    async def get_by_id(self, id: str) -> Union[dict, None]:
        """Get doc_full data by id."""
        sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
        params = {"workspace": self.db.workspace, "id": id}
        if "llm_response_cache" == self.namespace:
            # Cache rows come back one-per-mode; collapse them into {id: row}.
            array_res = await self.db.query(sql, params, multirows=True)
            res = {}
            for row in array_res:
                res[row["id"]] = row
        else:
            res = await self.db.query(sql, params)
        if res:
            return res
        else:
            return None

    # Query by id
    async def get_by_ids(self, ids: List[str], fields=None) -> Union[List[dict], None]:
        """Get doc_chunks data by id"""
        # NOTE: ids are interpolated into the SQL text; they are expected to be
        # internally generated hashes, not untrusted input.
        sql = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
            ids=",".join([f"'{id}'" for id in ids])
        )
        params = {"workspace": self.db.workspace}
        if "llm_response_cache" == self.namespace:
            # Re-group flat cache rows into one {mode: {id: row}} dict per mode.
            array_res = await self.db.query(sql, params, multirows=True)
            modes = set()
            dict_res: dict[str, dict] = {}
            for row in array_res:
                modes.add(row["mode"])
            for mode in modes:
                if mode not in dict_res:
                    dict_res[mode] = {}
            for row in array_res:
                dict_res[row["mode"]][row["id"]] = row
            res = [{k: v} for k, v in dict_res.items()]
        else:
            res = await self.db.query(sql, params, multirows=True)
        if res:
            return res
        else:
            return None

    async def filter_keys(self, keys: List[str]) -> Set[str]:
        """Return the subset of *keys* that is not yet present in storage."""
        sql = SQL_TEMPLATES["filter_keys"].format(
            table_name=NAMESPACE_TABLE_MAP[self.namespace],
            ids=",".join([f"'{id}'" for id in keys]),
        )
        params = {"workspace": self.db.workspace}
        try:
            res = await self.db.query(sql, params, multirows=True)
            if res:
                exist_keys = [key["id"] for key in res]
            else:
                exist_keys = []
            return set([s for s in keys if s not in exist_keys])
        except Exception as e:
            logger.error(f"PostgreSQL database error: {e}")
            print(sql)
            print(params)
            # BUGFIX: the original swallowed the error and implicitly returned
            # None, breaking the Set[str] contract for callers; re-raise so the
            # failure is visible.
            raise

    ################ INSERT METHODS ################
    async def upsert(self, data: Dict[str, dict]):
        """Persist *data* for this namespace; returns the not-seen-before items."""
        left_data = {k: v for k, v in data.items() if k not in self._data}
        self._data.update(left_data)
        if self.namespace == "text_chunks":
            # Chunk rows are written by the vector storage; nothing to do here.
            pass
        elif self.namespace == "full_docs":
            for k, v in self._data.items():
                upsert_sql = SQL_TEMPLATES["upsert_doc_full"]
                # Renamed from `data` to avoid shadowing the parameter.
                doc_row = {
                    "id": k,
                    "content": v["content"],
                    "workspace": self.db.workspace,
                }
                await self.db.execute(upsert_sql, doc_row)
        elif self.namespace == "llm_response_cache":
            # Cache payloads are nested: {mode: {cache_id: row}}.
            for mode, items in self._data.items():
                for k, v in items.items():
                    upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
                    cache_row = {
                        "workspace": self.db.workspace,
                        "id": k,
                        "original_prompt": v["original_prompt"],
                        "return": v["return"],
                        "mode": mode,
                    }
                    await self.db.execute(upsert_sql, cache_row)

        return left_data

    async def index_done_callback(self):
        # Rows are already written during upsert(); nothing left to flush.
        if self.namespace in ["full_docs", "text_chunks"]:
            logger.info("full doc and chunk data had been saved into postgresql db!")
239
+
240
+
241
@dataclass
class PGVectorStorage(BaseVectorStorage):
    """pgvector-backed vector storage for chunks, entities and relationships."""

    # Query templates filter out results scoring below this threshold.
    cosine_better_than_threshold: float = 0.2
    # Shared connection-pool wrapper; injected by the caller after construction.
    db: PostgreSQLDB = None

    def __post_init__(self):
        self._max_batch_size = self.global_config["embedding_batch_num"]
        # Allow global config to override the class default threshold.
        self.cosine_better_than_threshold = self.global_config.get(
            "cosine_better_than_threshold", self.cosine_better_than_threshold
        )

    def _upsert_chunks(self, item: dict):
        # Build (sql, params) for one text-chunk row; the embedding is stored
        # as a JSON-encoded float list.
        try:
            upsert_sql = SQL_TEMPLATES["upsert_chunk"]
            data = {
                "workspace": self.db.workspace,
                "id": item["__id__"],
                "tokens": item["tokens"],
                "chunk_order_index": item["chunk_order_index"],
                "full_doc_id": item["full_doc_id"],
                "content": item["content"],
                "content_vector": json.dumps(item["__vector__"].tolist()),
            }
        except Exception as e:
            logger.error(f"Error to prepare upsert sql: {e}")
            print(item)
            raise e
        return upsert_sql, data

    def _upsert_entities(self, item: dict):
        # Build (sql, params) for one entity row.
        upsert_sql = SQL_TEMPLATES["upsert_entity"]
        data = {
            "workspace": self.db.workspace,
            "id": item["__id__"],
            "entity_name": item["entity_name"],
            "content": item["content"],
            "content_vector": json.dumps(item["__vector__"].tolist()),
        }
        return upsert_sql, data

    def _upsert_relationships(self, item: dict):
        # Build (sql, params) for one relationship (edge) row.
        upsert_sql = SQL_TEMPLATES["upsert_relationship"]
        data = {
            "workspace": self.db.workspace,
            "id": item["__id__"],
            "source_id": item["src_id"],
            "target_id": item["tgt_id"],
            "content": item["content"],
            "content_vector": json.dumps(item["__vector__"].tolist()),
        }
        return upsert_sql, data

    async def upsert(self, data: Dict[str, dict]):
        """Embed all payloads in batches, then insert one row per item into
        the table matching ``self.namespace``."""
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
        if not len(data):
            logger.warning("You insert an empty data to vector DB")
            return []
        current_time = time.time()
        list_data = [
            {
                "__id__": k,
                "__created_at__": current_time,
                **{k1: v1 for k1, v1 in v.items()},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]

        async def wrapped_task(batch):
            # Run one embedding batch and tick the progress bar.
            result = await self.embedding_func(batch)
            pbar.update(1)
            return result

        embedding_tasks = [wrapped_task(batch) for batch in batches]
        pbar = tqdm_async(
            total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
        )
        embeddings_list = await asyncio.gather(*embedding_tasks)

        # Flatten per-batch embedding arrays back to one row per item.
        embeddings = np.concatenate(embeddings_list)
        for i, d in enumerate(list_data):
            d["__vector__"] = embeddings[i]
        for item in list_data:
            # Dispatch on namespace to the matching row builder.
            if self.namespace == "chunks":
                upsert_sql, data = self._upsert_chunks(item)
            elif self.namespace == "entities":
                upsert_sql, data = self._upsert_entities(item)
            elif self.namespace == "relationships":
                upsert_sql, data = self._upsert_relationships(item)
            else:
                raise ValueError(f"{self.namespace} is not supported")

            await self.db.execute(upsert_sql, data)

    async def index_done_callback(self):
        # Rows are already written during upsert(); nothing left to flush.
        logger.info("vector data had been saved into postgresql db!")

    #################### query method ###############
    async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
        """Query the vector database for the rows most similar to *query*."""
        embeddings = await self.embedding_func([query])
        embedding = embeddings[0]
        embedding_string = ",".join(map(str, embedding))

        sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
        params = {
            "workspace": self.db.workspace,
            "better_than_threshold": self.cosine_better_than_threshold,
            "top_k": top_k,
        }
        results = await self.db.query(sql, params=params, multirows=True)
        print("vector search result:", results)
        return results
362
+
363
@dataclass
class PGDocStatusStorage(DocStatusStorage):
    """PostgreSQL implementation of document status storage."""

    # Shared connection-pool wrapper; injected by the caller after construction.
    db: PostgreSQLDB = None

    def __post_init__(self):
        pass

    async def filter_keys(self, data: list[str]) -> set[str]:
        """Return keys that don't exist in storage."""
        # BUGFIX: build the IN-list outside the f-string — nesting double
        # quotes inside a double-quoted f-string is a SyntaxError before
        # Python 3.12.
        ids = ",".join([f"'{_id}'" for _id in data])
        sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id IN ({ids})"
        result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
        # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
        if result is None:
            return set(data)
        else:
            existed = set([element["id"] for element in result])
            return set(data) - existed

    async def get_status_counts(self) -> Dict[str, int]:
        """Get counts of documents in each status"""
        sql = '''SELECT status as "status", COUNT(1) as "count"
                 FROM LIGHTRAG_DOC_STATUS
                 where workspace=$1 GROUP BY STATUS
              '''
        result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
        # Result is like [{'status': 'PENDING', 'count': 1}, ...]
        counts = {}
        for doc in result:
            counts[doc["status"]] = doc["count"]
        return counts

    async def get_docs_by_status(self, status: DocStatus) -> Dict[str, DocProcessingStatus]:
        """Get all documents with the given status, keyed by document id."""
        # BUGFIX: the original used $1 for both placeholders; parameters are
        # bound positionally from dict insertion order (workspace -> $1,
        # status -> $2), so the status filter must be $2.
        sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$2"
        params = {"workspace": self.db.workspace, "status": status}
        result = await self.db.query(sql, params, True)
        # Convert the row list into {doc_id: DocProcessingStatus}.
        return {
            element["id"]: DocProcessingStatus(
                content_summary=element["content_summary"],
                content_length=element["content_length"],
                status=element["status"],
                created_at=element["created_at"],
                updated_at=element["updated_at"],
                chunks_count=element["chunks_count"],
            )
            for element in result
        }

    async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
        """Get all failed documents"""
        return await self.get_docs_by_status(DocStatus.FAILED)

    async def get_pending_docs(self) -> Dict[str, DocProcessingStatus]:
        """Get all pending documents"""
        return await self.get_docs_by_status(DocStatus.PENDING)

    async def index_done_callback(self):
        """Save data after indexing, but for PostgreSQL, we already saved them
        during the upsert stage, so no action to take here"""
        logger.info("Doc status had been saved into postgresql db!")

    async def upsert(self, data: dict[str, dict]):
        """Update or insert document status

        Args:
            data: Dictionary of document IDs and their status data
        """
        sql = """insert into LIGHTRAG_DOC_STATUS(workspace,id,content_summary,content_length,chunks_count,status)
                 values($1,$2,$3,$4,$5,$6)
                 on conflict(id,workspace) do update set
                 content_summary = EXCLUDED.content_summary,
                 content_length = EXCLUDED.content_length,
                 chunks_count = EXCLUDED.chunks_count,
                 status = EXCLUDED.status,
                 updated_at = CURRENT_TIMESTAMP"""
        for k, v in data.items():
            # chunks_count is optional; default to -1 when absent.
            await self.db.execute(
                sql,
                {
                    "workspace": self.db.workspace,
                    "id": k,
                    "content_summary": v["content_summary"],
                    "content_length": v["content_length"],
                    "chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
                    "status": v["status"],
                },
            )
        return data
447
+
448
+
449
class PGGraphQueryException(Exception):
    """Raised when an Apache AGE graph query fails.

    Accepts either a plain message string or a dict payload carrying
    ``message`` and ``details`` entries.
    """

    def __init__(self, exception: Union[str, Dict]) -> None:
        # Start from the fallbacks, then override from the payload.
        self.message = "unknown"
        self.details = "unknown"
        if isinstance(exception, dict):
            self.message = exception.get("message", "unknown")
            self.details = exception.get("details", "unknown")
        else:
            self.message = exception

    def get_message(self) -> str:
        return self.message

    def get_details(self) -> Any:
        return self.details
465
+
466
+
467
+ @dataclass
468
+ class PGGraphStorage(BaseGraphStorage):
469
+ db:PostgreSQLDB = None
470
+
471
    @staticmethod
    def load_nx_graph(file_name):
        # AGE keeps the graph inside PostgreSQL, so there is nothing to preload
        # from a NetworkX file; this stub only announces that fact.
        print("no preloading of graph with AGE in production")
474
+
475
    def __init__(self, namespace, global_config, embedding_func):
        super().__init__(
            namespace=namespace,
            global_config=global_config,
            embedding_func=embedding_func,
        )
        # Target graph name comes from the environment; raises KeyError if
        # AGE_GRAPH_NAME is not exported before construction.
        self.graph_name = os.environ["AGE_GRAPH_NAME"]
        # NOTE(review): _node2vec_embed is not defined in this chunk — confirm
        # it exists on the class or a base class.
        self._node_embed_algorithms = {
            "node2vec": self._node2vec_embed,
        }
485
+
486
+
487
    async def index_done_callback(self):
        # Graph mutations are written through immediately; nothing to flush.
        print("KG successfully indexed.")
489
+
490
+ async def check_graph_exists(self):
491
+ try:
492
+ res = await self.db.query(f"SELECT * FROM ag_catalog.ag_graph WHERE name = '{self.graph_name}'")
493
+ if res:
494
+ logger.info(f"Graph {self.graph_name} exists.")
495
+ else:
496
+ logger.info(f"Graph {self.graph_name} does not exist. Creating...")
497
+ await self.db.execute(f"SELECT create_graph('{self.graph_name}')", for_age=True)
498
+ logger.info(f"Graph {self.graph_name} created.")
499
+ except Exception as e:
500
+ logger.info(f"Failed to check/create graph {self.graph_name}:", e)
501
+ raise e
502
+
503
    @staticmethod
    def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]:
        """
        Convert a record returned from an age query to a dictionary

        Args:
            record (): a record from an age query result

        Returns:
            Dict[str, Any]: a dictionary representation of the record where
                the dictionary key is the field name and the value is the
                value converted to a python type
        """
        # result holder
        d = {}

        # prebuild a mapping of vertex_id to vertex mappings to be used
        # later to build edges
        vertices = {}
        for k in record.keys():
            v = record[k]
            # agtype comes back '{key: value}::type' which must be parsed
            if isinstance(v, str) and "::" in v:
                dtype = v.split("::")[-1]
                v = v.split("::")[0]
                if dtype == "vertex":
                    vertex = json.loads(v)
                    vertices[vertex["id"]] = vertex.get("properties")

        # iterate returned fields and parse appropriately
        for k in record.keys():
            v = record[k]
            if isinstance(v, str) and "::" in v:
                dtype = v.split("::")[-1]
                v = v.split("::")[0]
            else:
                dtype = ""

            if dtype == "vertex":
                vertex = json.loads(v)
                # vertices without properties map to an empty dict
                field = json.loads(v).get("properties")
                if not field:
                    field = {}
                # restore the original (hex-decoded) label for callers
                field["label"] = PGGraphStorage._decode_graph_label(vertex["label"])
                d[k] = field
            # convert edge from id-label->id by replacing id with node information
            # we only do this if the vertex was also returned in the query
            # this is an attempt to be consistent with neo4j implementation
            elif dtype == "edge":
                edge = json.loads(v)
                d[k] = (
                    vertices.get(edge["start_id"], {}),
                    edge[
                        "label"
                    ],  # we don't use decode_graph_label(), since edge label is always "DIRECTED"
                    vertices.get(edge["end_id"], {}),
                )
            else:
                # plain agtype scalar: JSON-decode strings, pass others through
                d[k] = json.loads(v) if isinstance(v, str) else v

        return d
564
+
565
+ @staticmethod
566
+ def _format_properties(
567
+ properties: Dict[str, Any], _id: Union[str, None] = None
568
+ ) -> str:
569
+ """
570
+ Convert a dictionary of properties to a string representation that
571
+ can be used in a cypher query insert/merge statement.
572
+
573
+ Args:
574
+ properties (Dict[str,str]): a dictionary containing node/edge properties
575
+ id (Union[str, None]): the id of the node or None if none exists
576
+
577
+ Returns:
578
+ str: the properties dictionary as a properly formatted string
579
+ """
580
+ props = []
581
+ # wrap property key in backticks to escape
582
+ for k, v in properties.items():
583
+ prop = f"`{k}`: {json.dumps(v)}"
584
+ props.append(prop)
585
+ if _id is not None and "id" not in properties:
586
+ props.append(
587
+ f"id: {json.dumps(_id)}" if isinstance(_id, str) else f"id: {_id}"
588
+ )
589
+ return "{" + ", ".join(props) + "}"
590
+
591
+ @staticmethod
592
+ def _encode_graph_label(label: str) -> str:
593
+ """
594
+ Since AGE suports only alphanumerical labels, we will encode generic label as HEX string
595
+
596
+ Args:
597
+ label (str): the original label
598
+
599
+ Returns:
600
+ str: the encoded label
601
+ """
602
+ return "x" + label.encode().hex()
603
+
604
+ @staticmethod
605
+ def _decode_graph_label(encoded_label: str) -> str:
606
+ """
607
+ Since AGE suports only alphanumerical labels, we will encode generic label as HEX string
608
+
609
+ Args:
610
+ encoded_label (str): the encoded label
611
+
612
+ Returns:
613
+ str: the decoded label
614
+ """
615
+ return bytes.fromhex(encoded_label.removeprefix("x")).decode()
616
+
617
+ @staticmethod
618
+ def _get_col_name(field: str, idx: int) -> str:
619
+ """
620
+ Convert a cypher return field to a pgsql select field
621
+ If possible keep the cypher column name, but create a generic name if necessary
622
+
623
+ Args:
624
+ field (str): a return field from a cypher query to be formatted for pgsql
625
+ idx (int): the position of the field in the return statement
626
+
627
+ Returns:
628
+ str: the field to be used in the pgsql select statement
629
+ """
630
+ # remove white space
631
+ field = field.strip()
632
+ # if an alias is provided for the field, use it
633
+ if " as " in field:
634
+ return field.split(" as ")[-1].strip()
635
+ # if the return value is an unnamed primitive, give it a generic name
636
+ if field.isnumeric() or field in ("true", "false", "null"):
637
+ return f"column_{idx}"
638
+ # otherwise return the value stripping out some common special chars
639
+ return field.replace("(", "_").replace(")", "")
640
+
641
    @staticmethod
    def _wrap_query(query: str, graph_name: str, **params: str) -> str:
        """
        Convert a cypher query to an Apache Age compatible
        sql query by wrapping the cypher query in ag_catalog.cypher,
        casting results to agtype and building a select statement

        Args:
            query (str): a valid cypher query
            graph_name (str): the name of the graph to query
            params (dict): parameters for the query

        Returns:
            str: an equivalent pgsql query
        """

        # pgsql template
        template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$
            {query}
        $$) AS ({fields});"""

        # if there are any returned fields they must be added to the pgsql query
        if "return" in query.lower():
            # parse return statement to identify returned fields
            # (everything between RETURN [DISTINCT] and ORDER BY/SKIP/LIMIT)
            fields = (
                query.lower()
                .split("return")[-1]
                .split("distinct")[-1]
                .split("order by")[0]
                .split("skip")[0]
                .split("limit")[0]
                .split(",")
            )

            # raise exception if RETURN * is found as we can't resolve the fields
            if "*" in [x.strip() for x in fields]:
                raise ValueError(
                    "AGE graph does not support 'RETURN *'"
                    + " statements in Cypher queries"
                )

            # get pgsql formatted field names
            fields = [
                PGGraphStorage._get_col_name(field, idx) for idx, field in enumerate(fields)
            ]

            # build resulting pgsql relation
            fields_str = ", ".join(
                [field.split(".")[-1] + " agtype" for field in fields]
            )

        # if no return statement we still need to return a single field of type agtype
        else:
            fields_str = "a agtype"

        # projection is needed on both paths, so it is assigned unconditionally
        select_str = "*"

        return template.format(
            graph_name=graph_name,
            query=query.format(**params),
            fields=fields_str,
            projection=select_str,
        )
704
+
705
    async def _query(self, query: str, readonly=True, **params: str) -> List[Dict[str, Any]]:
        """
        Query the graph by taking a cypher query, converting it to an
        age compatible query, executing it and converting the result

        Args:
            query (str): a cypher query to be executed
            readonly (bool): when True run through db.query; otherwise db.execute
            params (dict): parameters for the query

        Returns:
            List[Dict[str, Any]]: a list of dictionaries containing the result set
        """
        # convert cypher query to pgsql/age query
        wrapped_query = self._wrap_query(query, self.graph_name, **params)

        # execute the query, rolling back on an error
        try:
            if readonly:
                data = await self.db.query(wrapped_query, multirows=True, for_age=True)
            else:
                # NOTE(review): db.execute returns None, so non-readonly calls
                # always yield [] below — confirm that is intended.
                data = await self.db.execute(wrapped_query, for_age=True)
        except Exception as e:
            raise PGGraphQueryException(
                {
                    "message": f"Error executing graph query: {query.format(**params)}",
                    "wrapped": wrapped_query,
                    "detail": str(e),
                }
            ) from e

        if data is None:
            result = []
        # decode records
        else:
            result = [PGGraphStorage._record_to_dict(d) for d in data]

        return result
742
+
743
+ async def has_node(self, node_id: str) -> bool:
744
+ entity_name_label = node_id.strip('"')
745
+
746
+ query = """
747
+ MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists
748
+ """
749
+ params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
750
+ single_result = (await self._query(query, **params))[0]
751
+ logger.debug(
752
+ "{%s}:query:{%s}:result:{%s}",
753
+ inspect.currentframe().f_code.co_name,
754
+ query.format(**params),
755
+ single_result["node_exists"],
756
+ )
757
+
758
+ return single_result["node_exists"]
759
+
760
+ async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
761
+ entity_name_label_source = source_node_id.strip('"')
762
+ entity_name_label_target = target_node_id.strip('"')
763
+
764
+ query = """
765
+ MATCH (a:`{src_label}`)-[r]-(b:`{tgt_label}`)
766
+ RETURN COUNT(r) > 0 AS edge_exists
767
+ """
768
+ params = {
769
+ "src_label": PGGraphStorage._encode_graph_label(entity_name_label_source),
770
+ "tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target),
771
+ }
772
+ single_result = (await self._query(query, **params))[0]
773
+ logger.debug(
774
+ "{%s}:query:{%s}:result:{%s}",
775
+ inspect.currentframe().f_code.co_name,
776
+ query.format(**params),
777
+ single_result["edge_exists"],
778
+ )
779
+ return single_result["edge_exists"]
780
+
781
+ async def get_node(self, node_id: str) -> Union[dict, None]:
782
+ entity_name_label = node_id.strip('"')
783
+ query = """
784
+ MATCH (n:`{label}`) RETURN n
785
+ """
786
+ params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
787
+ record = await self._query(query, **params)
788
+ if record:
789
+ node = record[0]
790
+ node_dict = node["n"]
791
+ logger.debug(
792
+ "{%s}: query: {%s}, result: {%s}",
793
+ inspect.currentframe().f_code.co_name,
794
+ query.format(**params),
795
+ node_dict,
796
+ )
797
+ return node_dict
798
+ return None
799
+
800
+ async def node_degree(self, node_id: str) -> int:
801
+ entity_name_label = node_id.strip('"')
802
+
803
+ query = """
804
+ MATCH (n:`{label}`)-[]->(x)
805
+ RETURN count(x) AS total_edge_count
806
+ """
807
+ params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
808
+ record = (await self._query(query, **params))[0]
809
+ if record:
810
+ edge_count = int(record["total_edge_count"])
811
+ logger.debug(
812
+ "{%s}:query:{%s}:result:{%s}",
813
+ inspect.currentframe().f_code.co_name,
814
+ query.format(**params),
815
+ edge_count,
816
+ )
817
+ return edge_count
818
+
819
+ async def edge_degree(self, src_id: str, tgt_id: str) -> int:
820
+ entity_name_label_source = src_id.strip('"')
821
+ entity_name_label_target = tgt_id.strip('"')
822
+ src_degree = await self.node_degree(entity_name_label_source)
823
+ trg_degree = await self.node_degree(entity_name_label_target)
824
+
825
+ # Convert None to 0 for addition
826
+ src_degree = 0 if src_degree is None else src_degree
827
+ trg_degree = 0 if trg_degree is None else trg_degree
828
+
829
+ degrees = int(src_degree) + int(trg_degree)
830
+ logger.debug(
831
+ "{%s}:query:src_Degree+trg_degree:result:{%s}",
832
+ inspect.currentframe().f_code.co_name,
833
+ degrees,
834
+ )
835
+ return degrees
836
+
837
+ async def get_edge(
838
+ self, source_node_id: str, target_node_id: str
839
+ ) -> Union[dict, None]:
840
+ """
841
+ Find all edges between nodes of two given labels
842
+
843
+ Args:
844
+ source_node_label (str): Label of the source nodes
845
+ target_node_label (str): Label of the target nodes
846
+
847
+ Returns:
848
+ list: List of all relationships/edges found
849
+ """
850
+ entity_name_label_source = source_node_id.strip('"')
851
+ entity_name_label_target = target_node_id.strip('"')
852
+
853
+ query = """
854
+ MATCH (a:`{src_label}`)-[r]->(b:`{tgt_label}`)
855
+ RETURN properties(r) as edge_properties
856
+ LIMIT 1
857
+ """
858
+ params = {
859
+ "src_label": PGGraphStorage._encode_graph_label(entity_name_label_source),
860
+ "tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target),
861
+ }
862
+ record = await self._query(query, **params)
863
+ if record and record[0] and record[0]["edge_properties"]:
864
+ result = record[0]["edge_properties"]
865
+ logger.debug(
866
+ "{%s}:query:{%s}:result:{%s}",
867
+ inspect.currentframe().f_code.co_name,
868
+ query.format(**params),
869
+ result,
870
+ )
871
+ return result
872
+
873
+ async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
874
+ """
875
+ Retrieves all edges (relationships) for a particular node identified by its label.
876
+ :return: List of dictionaries containing edge information
877
+ """
878
+ node_label = source_node_id.strip('"')
879
+
880
+ query = """
881
+ MATCH (n:`{label}`)
882
+ OPTIONAL MATCH (n)-[r]-(connected)
883
+ RETURN n, r, connected
884
+ """
885
+ params = {"label": PGGraphStorage._encode_graph_label(node_label)}
886
+ results = await self._query(query, **params)
887
+ edges = []
888
+ for record in results:
889
+ source_node = record["n"] if record["n"] else None
890
+ connected_node = record["connected"] if record["connected"] else None
891
+
892
+ source_label = (
893
+ source_node["label"] if source_node and source_node["label"] else None
894
+ )
895
+ target_label = (
896
+ connected_node["label"]
897
+ if connected_node and connected_node["label"]
898
+ else None
899
+ )
900
+
901
+ if source_label and target_label:
902
+ edges.append((source_label, target_label))
903
+
904
+ return edges
905
+
906
+ @retry(
907
+ stop=stop_after_attempt(3),
908
+ wait=wait_exponential(multiplier=1, min=4, max=10),
909
+ retry=retry_if_exception_type((PGGraphQueryException,)),
910
+ )
911
+ async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
912
+ """
913
+ Upsert a node in the AGE database.
914
+
915
+ Args:
916
+ node_id: The unique identifier for the node (used as label)
917
+ node_data: Dictionary of node properties
918
+ """
919
+ label = node_id.strip('"')
920
+ properties = node_data
921
+
922
+ query = """
923
+ MERGE (n:`{label}`)
924
+ SET n += {properties}
925
+ """
926
+ params = {
927
+ "label": PGGraphStorage._encode_graph_label(label),
928
+ "properties": PGGraphStorage._format_properties(properties),
929
+ }
930
+ try:
931
+ await self._query(query, readonly=False, **params)
932
+ logger.debug(
933
+ "Upserted node with label '{%s}' and properties: {%s}",
934
+ label,
935
+ properties,
936
+ )
937
+ except Exception as e:
938
+ logger.error("Error during upsert: {%s}", e)
939
+ raise
940
+
941
+ @retry(
942
+ stop=stop_after_attempt(3),
943
+ wait=wait_exponential(multiplier=1, min=4, max=10),
944
+ retry=retry_if_exception_type((PGGraphQueryException,)),
945
+ )
946
+ async def upsert_edge(
947
+ self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
948
+ ):
949
+ """
950
+ Upsert an edge and its properties between two nodes identified by their labels.
951
+
952
+ Args:
953
+ source_node_id (str): Label of the source node (used as identifier)
954
+ target_node_id (str): Label of the target node (used as identifier)
955
+ edge_data (dict): Dictionary of properties to set on the edge
956
+ """
957
+ source_node_label = source_node_id.strip('"')
958
+ target_node_label = target_node_id.strip('"')
959
+ edge_properties = edge_data
960
+
961
+ query = """
962
+ MATCH (source:`{src_label}`)
963
+ WITH source
964
+ MATCH (target:`{tgt_label}`)
965
+ MERGE (source)-[r:DIRECTED]->(target)
966
+ SET r += {properties}
967
+ RETURN r
968
+ """
969
+ params = {
970
+ "src_label": PGGraphStorage._encode_graph_label(source_node_label),
971
+ "tgt_label": PGGraphStorage._encode_graph_label(target_node_label),
972
+ "properties": PGGraphStorage._format_properties(edge_properties),
973
+ }
974
+ try:
975
+ await self._query(query, readonly=False, **params)
976
+ logger.debug(
977
+ "Upserted edge from '{%s}' to '{%s}' with properties: {%s}",
978
+ source_node_label,
979
+ target_node_label,
980
+ edge_properties,
981
+ )
982
+ except Exception as e:
983
+ logger.error("Error during edge upsert: {%s}", e)
984
+ raise
985
+
986
    async def _node2vec_embed(self):
        # Placeholder required by the storage interface; this backend does
        # not implement node2vec embedding and this method is never invoked.
        print("Implemented but never called.")
988
+
989
+
990
# Maps LightRAG storage namespaces to the PostgreSQL tables that back them.
# Note both "text_chunks" and "chunks" resolve to LIGHTRAG_DOC_CHUNKS.
NAMESPACE_TABLE_MAP = {
    "full_docs": "LIGHTRAG_DOC_FULL",
    "text_chunks": "LIGHTRAG_DOC_CHUNKS",
    "chunks": "LIGHTRAG_DOC_CHUNKS",
    "entities": "LIGHTRAG_VDB_ENTITY",
    "relationships": "LIGHTRAG_VDB_RELATION",
    "doc_status": "LIGHTRAG_DOC_STATUS",
    "llm_response_cache": "LIGHTRAG_LLM_CACHE",
}
999
+
1000
+
1001
# DDL for every table the PostgreSQL backend needs; PostgreSQLDB.check_tables()
# presumably runs each "ddl" when the table is missing — TODO confirm it
# tolerates pre-existing tables (the statements lack IF NOT EXISTS).
TABLES = {
    "LIGHTRAG_DOC_FULL": {
        "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (
                    id VARCHAR(255),
                    workspace VARCHAR(255),
                    doc_name VARCHAR(1024),
                    content TEXT,
                    meta JSONB,
                    createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updatetime TIMESTAMP,
                    CONSTRAINT LIGHTRAG_DOC_FULL_PK PRIMARY KEY (workspace, id)
                    )"""
    },
    "LIGHTRAG_DOC_CHUNKS": {
        "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
                    id VARCHAR(255),
                    workspace VARCHAR(255),
                    full_doc_id VARCHAR(256),
                    chunk_order_index INTEGER,
                    tokens INTEGER,
                    content TEXT,
                    content_vector VECTOR,
                    createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updatetime TIMESTAMP,
                    CONSTRAINT LIGHTRAG_DOC_CHUNKS_PK PRIMARY KEY (workspace, id)
                    )"""
    },
    "LIGHTRAG_VDB_ENTITY": {
        "ddl": """CREATE TABLE LIGHTRAG_VDB_ENTITY (
                    id VARCHAR(255),
                    workspace VARCHAR(255),
                    entity_name VARCHAR(255),
                    content TEXT,
                    content_vector VECTOR,
                    createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updatetime TIMESTAMP,
                    CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id)
                    )"""
    },
    "LIGHTRAG_VDB_RELATION": {
        "ddl": """CREATE TABLE LIGHTRAG_VDB_RELATION (
                    id VARCHAR(255),
                    workspace VARCHAR(255),
                    source_id VARCHAR(256),
                    target_id VARCHAR(256),
                    content TEXT,
                    content_vector VECTOR,
                    createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updatetime TIMESTAMP,
                    CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id)
                    )"""
    },
    # NOTE(review): the "return" column is unquoted here but referenced as
    # "return" (quoted) in SQL_TEMPLATES; both fold to the same lowercase
    # name in PostgreSQL, but verify it stays valid on the target version.
    "LIGHTRAG_LLM_CACHE": {
        "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE (
                    workspace varchar(255) NOT NULL,
                    id varchar(255) NOT NULL,
                    mode varchar(32) NOT NULL,
                    original_prompt TEXT,
                    return TEXT,
                    createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updatetime TIMESTAMP,
                    CONSTRAINT LIGHTRAG_LLM_CACHE_PK PRIMARY KEY (workspace, id)
                    )"""
    },
    "LIGHTRAG_DOC_STATUS": {
        "ddl": """CREATE TABLE LIGHTRAG_DOC_STATUS (
                    workspace varchar(255) NOT NULL,
                    id varchar(255) NOT NULL,
                    content_summary varchar(255) NULL,
                    content_length int4 NULL,
                    chunks_count int4 NULL,
                    status varchar(64) NULL,
                    created_at timestamp DEFAULT CURRENT_TIMESTAMP NULL,
                    updated_at timestamp DEFAULT CURRENT_TIMESTAMP NULL,
                    CONSTRAINT LIGHTRAG_DOC_STATUS_PK PRIMARY KEY (workspace, id)
                    )"""
    },
}
1079
+
1080
+
1081
+
1082
# SQL statements used by the PostgreSQL KV / vector storage implementations.
# $n placeholders are driver bind parameters; {ids}, {embedding_string} and
# {table_name} are substituted via str.format before execution.
SQL_TEMPLATES = {
    # SQL for KVStorage
    "get_by_id_full_docs": """SELECT id, COALESCE(content, '') as content
                                FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id=$2
                            """,
    "get_by_id_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
                                chunk_order_index, full_doc_id
                                FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id=$2
                            """,
    # NOTE(review): despite the name, this filters by mode ($2), not id —
    # confirm against the KV storage caller.
    "get_by_id_llm_response_cache": """SELECT id, original_prompt, COALESCE("return", '') as "return", mode
                                FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode=$2
                            """,
    "get_by_ids_full_docs": """SELECT id, COALESCE(content, '') as content
                                FROM LIGHTRAG_DOC_FULL WHERE workspace=$1 AND id IN ({ids})
                            """,
    "get_by_ids_text_chunks": """SELECT id, tokens, COALESCE(content, '') as content,
                                chunk_order_index, full_doc_id
                                FROM LIGHTRAG_DOC_CHUNKS WHERE workspace=$1 AND id IN ({ids})
                            """,
    # fixed: was "mode= IN ({ids})", which is invalid SQL
    "get_by_ids_llm_response_cache": """SELECT id, original_prompt, COALESCE("return", '') as "return", mode
                                FROM LIGHTRAG_LLM_CACHE WHERE workspace=$1 AND mode IN ({ids})
                            """,
    "filter_keys": "SELECT id FROM {table_name} WHERE workspace=$1 AND id IN ({ids})",
    "upsert_doc_full": """INSERT INTO LIGHTRAG_DOC_FULL (id, content, workspace)
                        VALUES ($1, $2, $3)
                        ON CONFLICT (workspace,id) DO UPDATE
                        SET content = $2, updatetime = CURRENT_TIMESTAMP
                       """,
    "upsert_llm_response_cache": """INSERT INTO LIGHTRAG_LLM_CACHE(workspace,id,original_prompt,"return",mode)
                                      VALUES ($1, $2, $3, $4, $5)
                                      ON CONFLICT (workspace,id) DO UPDATE
                                      SET original_prompt = EXCLUDED.original_prompt,
                                      "return"=EXCLUDED."return",
                                      mode=EXCLUDED.mode,
                                      updatetime = CURRENT_TIMESTAMP
                                     """,
    "upsert_chunk": """INSERT INTO LIGHTRAG_DOC_CHUNKS (workspace, id, tokens,
                      chunk_order_index, full_doc_id, content, content_vector)
                      VALUES ($1, $2, $3, $4, $5, $6, $7)
                      ON CONFLICT (workspace,id) DO UPDATE
                      SET tokens=EXCLUDED.tokens,
                      chunk_order_index=EXCLUDED.chunk_order_index,
                      full_doc_id=EXCLUDED.full_doc_id,
                      content = EXCLUDED.content,
                      content_vector=EXCLUDED.content_vector,
                      updatetime = CURRENT_TIMESTAMP
                     """,
    # fixed: was VALUES ($1, ..., $6) for only five target columns
    "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector)
                      VALUES ($1, $2, $3, $4, $5)
                      ON CONFLICT (workspace,id) DO UPDATE
                      SET entity_name=EXCLUDED.entity_name,
                      content=EXCLUDED.content,
                      content_vector=EXCLUDED.content_vector,
                      updatetime=CURRENT_TIMESTAMP
                     """,
    "upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id,
                      target_id, content, content_vector)
                      VALUES ($1, $2, $3, $4, $5, $6)
                      ON CONFLICT (workspace,id) DO UPDATE
                      SET source_id=EXCLUDED.source_id,
                      target_id=EXCLUDED.target_id,
                      content=EXCLUDED.content,
                      content_vector=EXCLUDED.content_vector, updatetime = CURRENT_TIMESTAMP
                     """,
    # SQL for VectorStorage
    # fixed: PostgreSQL requires an alias on a subquery in FROM ("as sub")
    "entities": """SELECT entity_name FROM
        (SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
        FROM LIGHTRAG_VDB_ENTITY where workspace=$1) as sub
        WHERE distance>$2 ORDER BY distance DESC LIMIT $3
    """,
    "relationships": """SELECT source_id as src_id, target_id as tgt_id FROM
        (SELECT id, source_id,target_id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
        FROM LIGHTRAG_VDB_RELATION where workspace=$1) as sub
        WHERE distance>$2 ORDER BY distance DESC LIMIT $3
    """,
    "chunks": """SELECT id FROM
        (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
        FROM LIGHTRAG_DOC_CHUNKS where workspace=$1) as sub
        WHERE distance>$2 ORDER BY distance DESC LIMIT $3
    """,
}
lightrag/lightrag.py CHANGED
@@ -83,8 +83,12 @@ ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBS
83
  TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage")
84
  TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage")
85
  TiDBGraphStorage = lazy_external_import(".kg.tidb_impl", "TiDBGraphStorage")
 
 
86
  AGEStorage = lazy_external_import(".kg.age_impl", "AGEStorage")
 
87
  GremlinStorage = lazy_external_import(".kg.gremlin_impl", "GremlinStorage")
 
88
 
89
 
90
  def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
@@ -295,6 +299,10 @@ class LightRAG:
295
  "Neo4JStorage": Neo4JStorage,
296
  "OracleGraphStorage": OracleGraphStorage,
297
  "AGEStorage": AGEStorage,
 
 
 
 
298
  "TiDBGraphStorage": TiDBGraphStorage,
299
  "GremlinStorage": GremlinStorage,
300
  # "ArangoDBStorage": ArangoDBStorage
 
83
  TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage")
84
  TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage")
85
  TiDBGraphStorage = lazy_external_import(".kg.tidb_impl", "TiDBGraphStorage")
86
+ PGKVStorage = lazy_external_import(".kg.postgres_impl", "PGKVStorage")
87
+ PGVectorStorage = lazy_external_import(".kg.postgres_impl", "PGVectorStorage")
88
  AGEStorage = lazy_external_import(".kg.age_impl", "AGEStorage")
89
+ PGGraphStorage = lazy_external_import(".kg.postgres_impl", "PGGraphStorage")
90
  GremlinStorage = lazy_external_import(".kg.gremlin_impl", "GremlinStorage")
91
+ PGDocStatusStorage = lazy_external_import(".kg.postgres_impl", "PGDocStatusStorage")
92
 
93
 
94
  def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
 
299
  "Neo4JStorage": Neo4JStorage,
300
  "OracleGraphStorage": OracleGraphStorage,
301
  "AGEStorage": AGEStorage,
302
+ "PGGraphStorage": PGGraphStorage,
303
+ "PGKVStorage": PGKVStorage,
304
+ "PGDocStatusStorage": PGDocStatusStorage,
305
+ "PGVectorStorage": PGVectorStorage,
306
  "TiDBGraphStorage": TiDBGraphStorage,
307
  "GremlinStorage": GremlinStorage,
308
  # "ArangoDBStorage": ArangoDBStorage