zrguo commited on
Commit
1e5c642
·
unverified ·
2 Parent(s): aaf3e3f 2c6a893

Merge pull request #452 from Weaxs/main

Browse files

support TiDB: add TiDBKVStorage, TiDBVectorDBStorage

examples/lightrag_tidb_demo.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+
4
+ import numpy as np
5
+
6
+ from lightrag import LightRAG, QueryParam
7
+ from lightrag.kg.tidb_impl import TiDB
8
+ from lightrag.llm import siliconcloud_embedding, openai_complete_if_cache
9
+ from lightrag.utils import EmbeddingFunc
10
+
11
+ WORKING_DIR = "./dickens"
12
+
13
+ # We use SiliconCloud API to call LLM on Oracle Cloud
14
+ # More docs here https://docs.siliconflow.cn/introduction
15
+ BASE_URL = "https://api.siliconflow.cn/v1/"
16
+ APIKEY = ""
17
+ CHATMODEL = ""
18
+ EMBEDMODEL = ""
19
+
20
+ TIDB_HOST = ""
21
+ TIDB_PORT = ""
22
+ TIDB_USER = ""
23
+ TIDB_PASSWORD = ""
24
+ TIDB_DATABASE = ""
25
+
26
+
27
+ if not os.path.exists(WORKING_DIR):
28
+ os.mkdir(WORKING_DIR)
29
+
30
+
31
+ async def llm_model_func(
32
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
33
+ ) -> str:
34
+ return await openai_complete_if_cache(
35
+ CHATMODEL,
36
+ prompt,
37
+ system_prompt=system_prompt,
38
+ history_messages=history_messages,
39
+ api_key=APIKEY,
40
+ base_url=BASE_URL,
41
+ **kwargs,
42
+ )
43
+
44
+
45
+ async def embedding_func(texts: list[str]) -> np.ndarray:
46
+ return await siliconcloud_embedding(
47
+ texts,
48
+ # model=EMBEDMODEL,
49
+ api_key=APIKEY,
50
+ )
51
+
52
+
53
+ async def get_embedding_dim():
54
+ test_text = ["This is a test sentence."]
55
+ embedding = await embedding_func(test_text)
56
+ embedding_dim = embedding.shape[1]
57
+ return embedding_dim
58
+
59
+
60
+ async def main():
61
+ try:
62
+ # Detect embedding dimension
63
+ embedding_dimension = await get_embedding_dim()
64
+ print(f"Detected embedding dimension: {embedding_dimension}")
65
+
66
+ # Create TiDB DB connection
67
+ tidb = TiDB(
68
+ config={
69
+ "host": TIDB_HOST,
70
+ "port": TIDB_PORT,
71
+ "user": TIDB_USER,
72
+ "password": TIDB_PASSWORD,
73
+ "database": TIDB_DATABASE,
74
+ "workspace": "company", # specify which docs you want to store and query
75
+ }
76
+ )
77
+
78
+ # Check if TiDB DB tables exist, if not, tables will be created
79
+ await tidb.check_tables()
80
+
81
+ # Initialize LightRAG
82
+ # We use TiDB DB as the KV/vector
83
+ # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
84
+ rag = LightRAG(
85
+ enable_llm_cache=False,
86
+ working_dir=WORKING_DIR,
87
+ chunk_token_size=512,
88
+ llm_model_func=llm_model_func,
89
+ embedding_func=EmbeddingFunc(
90
+ embedding_dim=embedding_dimension,
91
+ max_token_size=512,
92
+ func=embedding_func,
93
+ ),
94
+ kv_storage="TiDBKVStorage",
95
+ vector_storage="TiDBVectorDBStorage",
96
+ )
97
+
98
+ if rag.llm_response_cache:
99
+ rag.llm_response_cache.db = tidb
100
+ rag.full_docs.db = tidb
101
+ rag.text_chunks.db = tidb
102
+ rag.entities_vdb.db = tidb
103
+ rag.relationships_vdb.db = tidb
104
+ rag.chunks_vdb.db = tidb
105
+
106
+ # Extract and Insert into LightRAG storage
107
+ with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
108
+ await rag.ainsert(f.read())
109
+
110
+ # Perform search in different modes
111
+ modes = ["naive", "local", "global", "hybrid"]
112
+ for mode in modes:
113
+ print("=" * 20, mode, "=" * 20)
114
+ print(
115
+ await rag.aquery(
116
+ "What are the top themes in this story?",
117
+ param=QueryParam(mode=mode),
118
+ )
119
+ )
120
+ print("-" * 100, "\n")
121
+
122
+ except Exception as e:
123
+ print(f"An error occurred: {e}")
124
+
125
+
126
+ if __name__ == "__main__":
127
+ asyncio.run(main())
lightrag/kg/tidb_impl.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import os
3
+ from dataclasses import dataclass
4
+ from typing import Union
5
+
6
+ import numpy as np
7
+ from sqlalchemy import create_engine, text
8
+ from tqdm import tqdm
9
+
10
+ from lightrag.base import BaseVectorStorage, BaseKVStorage
11
+ from lightrag.utils import logger
12
+
13
+
14
+ class TiDB(object):
15
+ def __init__(self, config, **kwargs):
16
+ self.host = config.get("host", None)
17
+ self.port = config.get("port", None)
18
+ self.user = config.get("user", None)
19
+ self.password = config.get("password", None)
20
+ self.database = config.get("database", None)
21
+ self.workspace = config.get("workspace", None)
22
+ connection_string = (
23
+ f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
24
+ f"?ssl_verify_cert=true&ssl_verify_identity=true"
25
+ )
26
+
27
+ try:
28
+ self.engine = create_engine(connection_string)
29
+ logger.info(f"Connected to TiDB database at {self.database}")
30
+ except Exception as e:
31
+ logger.error(f"Failed to connect to TiDB database at {self.database}")
32
+ logger.error(f"TiDB database error: {e}")
33
+ raise
34
+
35
+ async def check_tables(self):
36
+ for k, v in TABLES.items():
37
+ try:
38
+ await self.query(f"SELECT 1 FROM {k}".format(k=k))
39
+ except Exception as e:
40
+ logger.error(f"Failed to check table {k} in TiDB database")
41
+ logger.error(f"TiDB database error: {e}")
42
+ try:
43
+ # print(v["ddl"])
44
+ await self.execute(v["ddl"])
45
+ logger.info(f"Created table {k} in TiDB database")
46
+ except Exception as e:
47
+ logger.error(f"Failed to create table {k} in TiDB database")
48
+ logger.error(f"TiDB database error: {e}")
49
+
50
+ async def query(
51
+ self, sql: str, params: dict = None, multirows: bool = False
52
+ ) -> Union[dict, None]:
53
+ if params is None:
54
+ params = {"workspace": self.workspace}
55
+ else:
56
+ params.update({"workspace": self.workspace})
57
+ with self.engine.connect() as conn, conn.begin():
58
+ try:
59
+ result = conn.execute(text(sql), params)
60
+ except Exception as e:
61
+ logger.error(f"Tidb database error: {e}")
62
+ print(sql)
63
+ print(params)
64
+ raise
65
+ if multirows:
66
+ rows = result.all()
67
+ if rows:
68
+ data = [dict(zip(result.keys(), row)) for row in rows]
69
+ else:
70
+ data = []
71
+ else:
72
+ row = result.first()
73
+ if row:
74
+ data = dict(zip(result.keys(), row))
75
+ else:
76
+ data = None
77
+ return data
78
+
79
+ async def execute(self, sql: str, data: list | dict = None):
80
+ # logger.info("go into TiDBDB execute method")
81
+ try:
82
+ with self.engine.connect() as conn, conn.begin():
83
+ if data is None:
84
+ conn.execute(text(sql))
85
+ else:
86
+ conn.execute(text(sql), parameters=data)
87
+ except Exception as e:
88
+ logger.error(f"TiDB database error: {e}")
89
+ print(sql)
90
+ print(data)
91
+ raise
92
+
93
+
94
+ @dataclass
95
+ class TiDBKVStorage(BaseKVStorage):
96
+ # should pass db object to self.db
97
+ def __post_init__(self):
98
+ self._data = {}
99
+ self._max_batch_size = self.global_config["embedding_batch_num"]
100
+
101
+ ################ QUERY METHODS ################
102
+
103
+ async def get_by_id(self, id: str) -> Union[dict, None]:
104
+ """根据 id 获取 doc_full 数据."""
105
+ SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
106
+ params = {"id": id}
107
+ # print("get_by_id:"+SQL)
108
+ res = await self.db.query(SQL, params)
109
+ if res:
110
+ data = res # {"data":res}
111
+ # print (data)
112
+ return data
113
+ else:
114
+ return None
115
+
116
+ # Query by id
117
+ async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
118
+ """根据 id 获取 doc_chunks 数据"""
119
+ SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
120
+ ids=",".join([f"'{id}'" for id in ids])
121
+ )
122
+ # print("get_by_ids:"+SQL)
123
+ res = await self.db.query(SQL, multirows=True)
124
+ if res:
125
+ data = res # [{"data":i} for i in res]
126
+ # print(data)
127
+ return data
128
+ else:
129
+ return None
130
+
131
+ async def filter_keys(self, keys: list[str]) -> set[str]:
132
+ """过滤掉重复内容"""
133
+ SQL = SQL_TEMPLATES["filter_keys"].format(
134
+ table_name=N_T[self.namespace],
135
+ id_field=N_ID[self.namespace],
136
+ ids=",".join([f"'{id}'" for id in keys]),
137
+ )
138
+ try:
139
+ await self.db.query(SQL)
140
+ except Exception as e:
141
+ logger.error(f"Tidb database error: {e}")
142
+ print(SQL)
143
+ res = await self.db.query(SQL, multirows=True)
144
+ if res:
145
+ exist_keys = [key["id"] for key in res]
146
+ data = set([s for s in keys if s not in exist_keys])
147
+ else:
148
+ exist_keys = []
149
+ data = set([s for s in keys if s not in exist_keys])
150
+ return data
151
+
152
+ ################ INSERT full_doc AND chunks ################
153
+ async def upsert(self, data: dict[str, dict]):
154
+ left_data = {k: v for k, v in data.items() if k not in self._data}
155
+ self._data.update(left_data)
156
+ if self.namespace == "text_chunks":
157
+ list_data = [
158
+ {
159
+ "__id__": k,
160
+ **{k1: v1 for k1, v1 in v.items()},
161
+ }
162
+ for k, v in data.items()
163
+ ]
164
+ contents = [v["content"] for v in data.values()]
165
+ batches = [
166
+ contents[i : i + self._max_batch_size]
167
+ for i in range(0, len(contents), self._max_batch_size)
168
+ ]
169
+ embeddings_list = await asyncio.gather(
170
+ *[self.embedding_func(batch) for batch in batches]
171
+ )
172
+ embeddings = np.concatenate(embeddings_list)
173
+ for i, d in enumerate(list_data):
174
+ d["__vector__"] = embeddings[i]
175
+
176
+ merge_sql = SQL_TEMPLATES["upsert_chunk"]
177
+ data = []
178
+ for item in list_data:
179
+ data.append(
180
+ {
181
+ "id": item["__id__"],
182
+ "content": item["content"],
183
+ "tokens": item["tokens"],
184
+ "chunk_order_index": item["chunk_order_index"],
185
+ "full_doc_id": item["full_doc_id"],
186
+ "content_vector": f"{item["__vector__"].tolist()}",
187
+ "workspace": self.db.workspace,
188
+ }
189
+ )
190
+ await self.db.execute(merge_sql, data)
191
+
192
+ if self.namespace == "full_docs":
193
+ merge_sql = SQL_TEMPLATES["upsert_doc_full"]
194
+ data = []
195
+ for k, v in self._data.items():
196
+ data.append(
197
+ {
198
+ "id": k,
199
+ "content": v["content"],
200
+ "workspace": self.db.workspace,
201
+ }
202
+ )
203
+ await self.db.execute(merge_sql, data)
204
+ return left_data
205
+
206
+ async def index_done_callback(self):
207
+ if self.namespace in ["full_docs", "text_chunks"]:
208
+ logger.info("full doc and chunk data had been saved into TiDB db!")
209
+
210
+
211
+ @dataclass
212
+ class TiDBVectorDBStorage(BaseVectorStorage):
213
+ cosine_better_than_threshold: float = 0.2
214
+
215
+ def __post_init__(self):
216
+ self._client_file_name = os.path.join(
217
+ self.global_config["working_dir"], f"vdb_{self.namespace}.json"
218
+ )
219
+ self._max_batch_size = self.global_config["embedding_batch_num"]
220
+ self.cosine_better_than_threshold = self.global_config.get(
221
+ "cosine_better_than_threshold", self.cosine_better_than_threshold
222
+ )
223
+
224
+ async def query(self, query: str, top_k: int) -> list[dict]:
225
+ """search from tidb vector"""
226
+
227
+ embeddings = await self.embedding_func([query])
228
+ embedding = embeddings[0]
229
+
230
+ embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"
231
+
232
+ params = {
233
+ "embedding_string": embedding_string,
234
+ "top_k": top_k,
235
+ "better_than_threshold": self.cosine_better_than_threshold,
236
+ }
237
+
238
+ results = await self.db.query(
239
+ SQL_TEMPLATES[self.namespace], params=params, multirows=True
240
+ )
241
+ print("vector search result:", results)
242
+ if not results:
243
+ return []
244
+ return results
245
+
246
+ ###### INSERT entities And relationships ######
247
+ async def upsert(self, data: dict[str, dict]):
248
+ # ignore, upsert in TiDBKVStorage already
249
+ if not len(data):
250
+ logger.warning("You insert an empty data to vector DB")
251
+ return []
252
+ if self.namespace == "chunks":
253
+ return []
254
+ logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
255
+
256
+ list_data = [
257
+ {
258
+ "id": k,
259
+ **{k1: v1 for k1, v1 in v.items()},
260
+ }
261
+ for k, v in data.items()
262
+ ]
263
+ contents = [v["content"] for v in data.values()]
264
+ batches = [
265
+ contents[i : i + self._max_batch_size]
266
+ for i in range(0, len(contents), self._max_batch_size)
267
+ ]
268
+ embedding_tasks = [self.embedding_func(batch) for batch in batches]
269
+ embeddings_list = []
270
+ for f in tqdm(
271
+ asyncio.as_completed(embedding_tasks),
272
+ total=len(embedding_tasks),
273
+ desc="Generating embeddings",
274
+ unit="batch",
275
+ ):
276
+ embeddings = await f
277
+ embeddings_list.append(embeddings)
278
+ embeddings = np.concatenate(embeddings_list)
279
+ for i, d in enumerate(list_data):
280
+ d["content_vector"] = embeddings[i]
281
+
282
+ if self.namespace == "entities":
283
+ data = []
284
+ for item in list_data:
285
+ merge_sql = SQL_TEMPLATES["upsert_entity"]
286
+ data.append(
287
+ {
288
+ "id": item["id"],
289
+ "name": item["entity_name"],
290
+ "content": item["content"],
291
+ "content_vector": f"{item["content_vector"].tolist()}",
292
+ "workspace": self.db.workspace,
293
+ }
294
+ )
295
+ await self.db.execute(merge_sql, data)
296
+
297
+ elif self.namespace == "relationships":
298
+ data = []
299
+ for item in list_data:
300
+ merge_sql = SQL_TEMPLATES["upsert_relationship"]
301
+ data.append(
302
+ {
303
+ "id": item["id"],
304
+ "source_name": item["src_id"],
305
+ "target_name": item["tgt_id"],
306
+ "content": item["content"],
307
+ "content_vector": f"{item["content_vector"].tolist()}",
308
+ "workspace": self.db.workspace,
309
+ }
310
+ )
311
+ await self.db.execute(merge_sql, data)
312
+
313
+
314
+ N_T = {
315
+ "full_docs": "LIGHTRAG_DOC_FULL",
316
+ "text_chunks": "LIGHTRAG_DOC_CHUNKS",
317
+ "chunks": "LIGHTRAG_DOC_CHUNKS",
318
+ "entities": "LIGHTRAG_GRAPH_NODES",
319
+ "relationships": "LIGHTRAG_GRAPH_EDGES",
320
+ }
321
+ N_ID = {
322
+ "full_docs": "doc_id",
323
+ "text_chunks": "chunk_id",
324
+ "chunks": "chunk_id",
325
+ "entities": "entity_id",
326
+ "relationships": "relation_id",
327
+ }
328
+
329
+ TABLES = {
330
+ "LIGHTRAG_DOC_FULL": {
331
+ "ddl": """
332
+ CREATE TABLE LIGHTRAG_DOC_FULL (
333
+ `id` BIGINT PRIMARY KEY AUTO_RANDOM,
334
+ `doc_id` VARCHAR(256) NOT NULL,
335
+ `workspace` varchar(1024),
336
+ `content` LONGTEXT,
337
+ `meta` JSON,
338
+ `createtime` TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
339
+ `updatetime` TIMESTAMP DEFAULT NULL,
340
+ UNIQUE KEY (`doc_id`)
341
+ );
342
+ """
343
+ },
344
+ "LIGHTRAG_DOC_CHUNKS": {
345
+ "ddl": """
346
+ CREATE TABLE LIGHTRAG_DOC_CHUNKS (
347
+ `id` BIGINT PRIMARY KEY AUTO_RANDOM,
348
+ `chunk_id` VARCHAR(256) NOT NULL,
349
+ `full_doc_id` VARCHAR(256) NOT NULL,
350
+ `workspace` varchar(1024),
351
+ `chunk_order_index` INT,
352
+ `tokens` INT,
353
+ `content` LONGTEXT,
354
+ `content_vector` VECTOR,
355
+ `createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
356
+ `updatetime` DATETIME DEFAULT NULL,
357
+ UNIQUE KEY (`chunk_id`)
358
+ );
359
+ """
360
+ },
361
+ "LIGHTRAG_GRAPH_NODES": {
362
+ "ddl": """
363
+ CREATE TABLE LIGHTRAG_GRAPH_NODES (
364
+ `id` BIGINT PRIMARY KEY AUTO_RANDOM,
365
+ `entity_id` VARCHAR(256) NOT NULL,
366
+ `workspace` varchar(1024),
367
+ `name` VARCHAR(2048),
368
+ `content` LONGTEXT,
369
+ `content_vector` VECTOR,
370
+ `createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
371
+ `updatetime` DATETIME DEFAULT NULL,
372
+ UNIQUE KEY (`entity_id`)
373
+ );
374
+ """
375
+ },
376
+ "LIGHTRAG_GRAPH_EDGES": {
377
+ "ddl": """
378
+ CREATE TABLE LIGHTRAG_GRAPH_EDGES (
379
+ `id` BIGINT PRIMARY KEY AUTO_RANDOM,
380
+ `relation_id` VARCHAR(256) NOT NULL,
381
+ `workspace` varchar(1024),
382
+ `source_name` VARCHAR(2048),
383
+ `target_name` VARCHAR(2048),
384
+ `content` LONGTEXT,
385
+ `content_vector` VECTOR,
386
+ `createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
387
+ `updatetime` DATETIME DEFAULT NULL,
388
+ UNIQUE KEY (`relation_id`)
389
+ );
390
+ """
391
+ },
392
+ "LIGHTRAG_LLM_CACHE": {
393
+ "ddl": """
394
+ CREATE TABLE LIGHTRAG_LLM_CACHE (
395
+ id BIGINT PRIMARY KEY AUTO_INCREMENT,
396
+ send TEXT,
397
+ return TEXT,
398
+ model VARCHAR(1024),
399
+ createtime DATETIME DEFAULT CURRENT_TIMESTAMP,
400
+ updatetime DATETIME DEFAULT NULL
401
+ );
402
+ """
403
+ },
404
+ }
405
+
406
+
407
+ SQL_TEMPLATES = {
408
+ # SQL for KVStorage
409
+ "get_by_id_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM LIGHTRAG_DOC_FULL WHERE doc_id = :id AND workspace = :workspace",
410
+ "get_by_id_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE chunk_id = :id AND workspace = :workspace",
411
+ "get_by_ids_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM LIGHTRAG_DOC_FULL WHERE doc_id IN ({ids}) AND workspace = :workspace",
412
+ "get_by_ids_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE chunk_id IN ({ids}) AND workspace = :workspace",
413
+ "filter_keys": "SELECT {id_field} AS id FROM {table_name} WHERE {id_field} IN ({ids}) AND workspace = :workspace",
414
+ # SQL for Merge operations (TiDB version with INSERT ... ON DUPLICATE KEY UPDATE)
415
+ "upsert_doc_full": """
416
+ INSERT INTO LIGHTRAG_DOC_FULL (doc_id, content, workspace)
417
+ VALUES (:id, :content, :workspace)
418
+ ON DUPLICATE KEY UPDATE content = VALUES(content), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
419
+ """,
420
+ "upsert_chunk": """
421
+ INSERT INTO LIGHTRAG_DOC_CHUNKS(chunk_id, content, tokens, chunk_order_index, full_doc_id, content_vector, workspace)
422
+ VALUES (:id, :content, :tokens, :chunk_order_index, :full_doc_id, :content_vector, :workspace)
423
+ ON DUPLICATE KEY UPDATE
424
+ content = VALUES(content), tokens = VALUES(tokens), chunk_order_index = VALUES(chunk_order_index),
425
+ full_doc_id = VALUES(full_doc_id), content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
426
+ """,
427
+ # SQL for VectorStorage
428
+ "entities": """SELECT n.name as entity_name FROM
429
+ (SELECT entity_id as id, name, VEC_COSINE_DISTANCE(content_vector,:embedding_string) as distance
430
+ FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace) n
431
+ WHERE n.distance>:better_than_threshold ORDER BY n.distance DESC LIMIT :top_k""",
432
+ "relationships": """SELECT e.source_name as src_id, e.target_name as tgt_id FROM
433
+ (SELECT source_name, target_name, VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance
434
+ FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace) e
435
+ WHERE e.distance>:better_than_threshold ORDER BY e.distance DESC LIMIT :top_k""",
436
+ "chunks": """SELECT c.id FROM
437
+ (SELECT chunk_id as id,VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance
438
+ FROM LIGHTRAG_DOC_CHUNKS WHERE workspace = :workspace) c
439
+ WHERE c.distance>:better_than_threshold ORDER BY c.distance DESC LIMIT :top_k""",
440
+ "upsert_entity": """
441
+ INSERT INTO LIGHTRAG_GRAPH_NODES(entity_id, name, content, content_vector, workspace)
442
+ VALUES(:id, :name, :content, :content_vector, :workspace)
443
+ ON DUPLICATE KEY UPDATE
444
+ name = VALUES(name), content = VALUES(content), content_vector = VALUES(content_vector),
445
+ workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
446
+ """,
447
+ "upsert_relationship": """
448
+ INSERT INTO LIGHTRAG_GRAPH_EDGES(relation_id, source_name, target_name, content, content_vector, workspace)
449
+ VALUES(:id, :source_name, :target_name, :content, :content_vector, :workspace)
450
+ ON DUPLICATE KEY UPDATE
451
+ source_name = VALUES(source_name), target_name = VALUES(target_name), content = VALUES(content),
452
+ content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
453
+ """,
454
+ }
lightrag/lightrag.py CHANGED
@@ -77,6 +77,8 @@ OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBS
77
  MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge")
78
  MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage")
79
  ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage")
 
 
80
 
81
 
82
  def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
@@ -260,11 +262,13 @@ class LightRAG:
260
  "JsonKVStorage": JsonKVStorage,
261
  "OracleKVStorage": OracleKVStorage,
262
  "MongoKVStorage": MongoKVStorage,
 
263
  # vector storage
264
  "NanoVectorDBStorage": NanoVectorDBStorage,
265
  "OracleVectorDBStorage": OracleVectorDBStorage,
266
  "MilvusVectorDBStorge": MilvusVectorDBStorge,
267
  "ChromaVectorDBStorage": ChromaVectorDBStorage,
 
268
  # graph storage
269
  "NetworkXStorage": NetworkXStorage,
270
  "Neo4JStorage": Neo4JStorage,
 
77
  MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge")
78
  MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage")
79
  ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage")
80
+ TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage")
81
+ TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage")
82
 
83
 
84
  def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
 
262
  "JsonKVStorage": JsonKVStorage,
263
  "OracleKVStorage": OracleKVStorage,
264
  "MongoKVStorage": MongoKVStorage,
265
+ "TiDBKVStorage": TiDBKVStorage,
266
  # vector storage
267
  "NanoVectorDBStorage": NanoVectorDBStorage,
268
  "OracleVectorDBStorage": OracleVectorDBStorage,
269
  "MilvusVectorDBStorge": MilvusVectorDBStorge,
270
  "ChromaVectorDBStorage": ChromaVectorDBStorage,
271
+ "TiDBVectorDBStorage": TiDBVectorDBStorage,
272
  # graph storage
273
  "NetworkXStorage": NetworkXStorage,
274
  "Neo4JStorage": Neo4JStorage,
requirements.txt CHANGED
@@ -13,9 +13,12 @@ openai
13
  oracledb
14
  pymilvus
15
  pymongo
 
16
  pyvis
17
- tenacity
18
  # lmdeploy[all]
 
 
 
19
 
20
  # LLM packages
21
  tiktoken
 
13
  oracledb
14
  pymilvus
15
  pymongo
16
+ pymysql
17
  pyvis
 
18
  # lmdeploy[all]
19
+ sqlalchemy
20
+ tenacity
21
+
22
 
23
  # LLM packages
24
  tiktoken