Merge pull request #452 from Weaxs/main
support TiDB: add TiDBKVStorage, TiDBVectorDBStorage
- examples/lightrag_tidb_demo.py +127 -0
- lightrag/kg/tidb_impl.py +454 -0
- lightrag/lightrag.py +4 -0
- requirements.txt +4 -1
examples/lightrag_tidb_demo.py
ADDED
@@ -0,0 +1,127 @@
import asyncio
import os

import numpy as np

from lightrag import LightRAG, QueryParam
from lightrag.kg.tidb_impl import TiDB
from lightrag.llm import siliconcloud_embedding, openai_complete_if_cache
from lightrag.utils import EmbeddingFunc

WORKING_DIR = "./dickens"

# We use the SiliconCloud API to call the LLM.
# More docs here: https://docs.siliconflow.cn/introduction
BASE_URL = "https://api.siliconflow.cn/v1/"
APIKEY = ""
CHATMODEL = ""
EMBEDMODEL = ""

TIDB_HOST = ""
TIDB_PORT = ""
TIDB_USER = ""
TIDB_PASSWORD = ""
TIDB_DATABASE = ""


if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)


async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
    return await openai_complete_if_cache(
        CHATMODEL,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        api_key=APIKEY,
        base_url=BASE_URL,
        **kwargs,
    )


async def embedding_func(texts: list[str]) -> np.ndarray:
    return await siliconcloud_embedding(
        texts,
        # model=EMBEDMODEL,
        api_key=APIKEY,
    )


async def get_embedding_dim():
    test_text = ["This is a test sentence."]
    embedding = await embedding_func(test_text)
    embedding_dim = embedding.shape[1]
    return embedding_dim


async def main():
    try:
        # Detect embedding dimension
        embedding_dimension = await get_embedding_dim()
        print(f"Detected embedding dimension: {embedding_dimension}")

        # Create TiDB connection
        tidb = TiDB(
            config={
                "host": TIDB_HOST,
                "port": TIDB_PORT,
                "user": TIDB_USER,
                "password": TIDB_PASSWORD,
                "database": TIDB_DATABASE,
                "workspace": "company",  # specify which docs you want to store and query
            }
        )

        # Check whether the TiDB tables exist; if not, they will be created
        await tidb.check_tables()

        # Initialize LightRAG
        # We use TiDB as the KV/vector storage
        # You can add `addon_params={"example_number": 1, "language": "Simplified Chinese"}` to control the prompt
        rag = LightRAG(
            enable_llm_cache=False,
            working_dir=WORKING_DIR,
            chunk_token_size=512,
            llm_model_func=llm_model_func,
            embedding_func=EmbeddingFunc(
                embedding_dim=embedding_dimension,
                max_token_size=512,
                func=embedding_func,
            ),
            kv_storage="TiDBKVStorage",
            vector_storage="TiDBVectorDBStorage",
        )

        # Point every TiDB-backed storage at the shared connection
        if rag.llm_response_cache:
            rag.llm_response_cache.db = tidb
        rag.full_docs.db = tidb
        rag.text_chunks.db = tidb
        rag.entities_vdb.db = tidb
        rag.relationships_vdb.db = tidb
        rag.chunks_vdb.db = tidb

        # Extract and insert into LightRAG storage
        with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
            await rag.ainsert(f.read())

        # Perform search in different modes
        modes = ["naive", "local", "global", "hybrid"]
        for mode in modes:
            print("=" * 20, mode, "=" * 20)
            print(
                await rag.aquery(
                    "What are the top themes in this story?",
                    param=QueryParam(mode=mode),
                )
            )
            print("-" * 100, "\n")

    except Exception as e:
        print(f"An error occurred: {e}")


if __name__ == "__main__":
    asyncio.run(main())
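The demo assigns the shared TiDB client to each storage attribute by hand. For illustration, a compact equivalent of that wiring (a sketch against the names above, not part of the patch):

# Sketch: attach the shared TiDB client to every TiDB-backed storage in one pass.
stores = [rag.full_docs, rag.text_chunks, rag.entities_vdb,
          rag.relationships_vdb, rag.chunks_vdb]
if rag.llm_response_cache:
    stores.append(rag.llm_response_cache)
for store in stores:
    store.db = tidb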
lightrag/kg/tidb_impl.py
ADDED
@@ -0,0 +1,454 @@
import asyncio
import os
from dataclasses import dataclass
from typing import Union

import numpy as np
from sqlalchemy import create_engine, text
from tqdm import tqdm

from lightrag.base import BaseVectorStorage, BaseKVStorage
from lightrag.utils import logger


class TiDB:
    def __init__(self, config, **kwargs):
        self.host = config.get("host", None)
        self.port = config.get("port", None)
        self.user = config.get("user", None)
        self.password = config.get("password", None)
        self.database = config.get("database", None)
        self.workspace = config.get("workspace", None)
        connection_string = (
            f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
            f"?ssl_verify_cert=true&ssl_verify_identity=true"
        )

        try:
            self.engine = create_engine(connection_string)
            logger.info(f"Connected to TiDB database at {self.database}")
        except Exception as e:
            logger.error(f"Failed to connect to TiDB database at {self.database}")
            logger.error(f"TiDB database error: {e}")
            raise

    async def check_tables(self):
        for k, v in TABLES.items():
            try:
                await self.query(f"SELECT 1 FROM {k}")
            except Exception as e:
                logger.error(f"Failed to check table {k} in TiDB database")
                logger.error(f"TiDB database error: {e}")
                try:
                    await self.execute(v["ddl"])
                    logger.info(f"Created table {k} in TiDB database")
                except Exception as e:
                    logger.error(f"Failed to create table {k} in TiDB database")
                    logger.error(f"TiDB database error: {e}")

    async def query(
        self, sql: str, params: dict = None, multirows: bool = False
    ) -> Union[dict, None]:
        if params is None:
            params = {"workspace": self.workspace}
        else:
            params.update({"workspace": self.workspace})
        with self.engine.connect() as conn, conn.begin():
            try:
                result = conn.execute(text(sql), params)
            except Exception as e:
                logger.error(f"TiDB database error: {e}")
                print(sql)
                print(params)
                raise
            if multirows:
                rows = result.all()
                if rows:
                    data = [dict(zip(result.keys(), row)) for row in rows]
                else:
                    data = []
            else:
                row = result.first()
                if row:
                    data = dict(zip(result.keys(), row))
                else:
                    data = None
            return data

    async def execute(self, sql: str, data: list | dict = None):
        try:
            with self.engine.connect() as conn, conn.begin():
                if data is None:
                    conn.execute(text(sql))
                else:
                    conn.execute(text(sql), parameters=data)
        except Exception as e:
            logger.error(f"TiDB database error: {e}")
            print(sql)
            print(data)
            raise

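# A minimal usage sketch for the helper above (an aside, assuming a reachable
# TiDB instance; not part of this file). Note that query() merges `workspace`
# into the bind params of every statement:
#
#     tidb = TiDB(config={"host": "127.0.0.1", "port": 4000, "user": "root",
#                         "password": "", "database": "test", "workspace": "demo"})
#     await tidb.check_tables()                 # creates any missing TABLES
#     row = await tidb.query("SELECT 1 AS ok")  # -> {"ok": 1}
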
@dataclass
class TiDBKVStorage(BaseKVStorage):
    # The caller is expected to assign the shared TiDB client to `self.db`.
    def __post_init__(self):
        self._data = {}
        self._max_batch_size = self.global_config["embedding_batch_num"]

    ################ QUERY METHODS ################

    async def get_by_id(self, id: str) -> Union[dict, None]:
        """Fetch doc_full data by id."""
        SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
        params = {"id": id}
        res = await self.db.query(SQL, params)
        if res:
            return res
        else:
            return None

    # Query by ids
    async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
        """Fetch doc_chunks data by ids."""
        SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
            ids=",".join([f"'{id}'" for id in ids])
        )
        res = await self.db.query(SQL, multirows=True)
        if res:
            return res
        else:
            return None

    async def filter_keys(self, keys: list[str]) -> set[str]:
        """Filter out keys whose content is already stored."""
        SQL = SQL_TEMPLATES["filter_keys"].format(
            table_name=N_T[self.namespace],
            id_field=N_ID[self.namespace],
            ids=",".join([f"'{id}'" for id in keys]),
        )
        res = await self.db.query(SQL, multirows=True)
        if res:
            exist_keys = [key["id"] for key in res]
        else:
            exist_keys = []
        return set([s for s in keys if s not in exist_keys])

    ################ INSERT full_doc AND chunks ################
    async def upsert(self, data: dict[str, dict]):
        left_data = {k: v for k, v in data.items() if k not in self._data}
        self._data.update(left_data)
        if self.namespace == "text_chunks":
            list_data = [
                {
                    "__id__": k,
                    **{k1: v1 for k1, v1 in v.items()},
                }
                for k, v in data.items()
            ]
            contents = [v["content"] for v in data.values()]
            batches = [
                contents[i : i + self._max_batch_size]
                for i in range(0, len(contents), self._max_batch_size)
            ]
            embeddings_list = await asyncio.gather(
                *[self.embedding_func(batch) for batch in batches]
            )
            embeddings = np.concatenate(embeddings_list)
            for i, d in enumerate(list_data):
                d["__vector__"] = embeddings[i]

            merge_sql = SQL_TEMPLATES["upsert_chunk"]
            data = []
            for item in list_data:
                data.append(
                    {
                        "id": item["__id__"],
                        "content": item["content"],
                        "tokens": item["tokens"],
                        "chunk_order_index": item["chunk_order_index"],
                        "full_doc_id": item["full_doc_id"],
                        "content_vector": f"{item['__vector__'].tolist()}",
                        "workspace": self.db.workspace,
                    }
                )
            await self.db.execute(merge_sql, data)

        if self.namespace == "full_docs":
            merge_sql = SQL_TEMPLATES["upsert_doc_full"]
            data = []
            for k, v in self._data.items():
                data.append(
                    {
                        "id": k,
                        "content": v["content"],
                        "workspace": self.db.workspace,
                    }
                )
            await self.db.execute(merge_sql, data)
        return left_data

    async def index_done_callback(self):
        if self.namespace in ["full_docs", "text_chunks"]:
            logger.info("full doc and chunk data have been saved into TiDB!")

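# Sketch of how the KV storage is typically driven (an aside, assuming the
# storage wiring from the demo): `namespace` selects the table, and upserting
# text_chunks also embeds the content before it is written.
#
#     new_keys = await rag.full_docs.filter_keys(["doc-1", "doc-2"])
#     await rag.full_docs.upsert({"doc-1": {"content": "full document text"}})
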
@dataclass
class TiDBVectorDBStorage(BaseVectorStorage):
    cosine_better_than_threshold: float = 0.2

    def __post_init__(self):
        self._client_file_name = os.path.join(
            self.global_config["working_dir"], f"vdb_{self.namespace}.json"
        )
        self._max_batch_size = self.global_config["embedding_batch_num"]
        self.cosine_better_than_threshold = self.global_config.get(
            "cosine_better_than_threshold", self.cosine_better_than_threshold
        )

    async def query(self, query: str, top_k: int) -> list[dict]:
        """Search the TiDB vector store for the query string."""
        embeddings = await self.embedding_func([query])
        embedding = embeddings[0]

        embedding_string = "[" + ", ".join(map(str, embedding.tolist())) + "]"

        params = {
            "embedding_string": embedding_string,
            "top_k": top_k,
            "better_than_threshold": self.cosine_better_than_threshold,
        }

        results = await self.db.query(
            SQL_TEMPLATES[self.namespace], params=params, multirows=True
        )
        print("vector search result:", results)
        if not results:
            return []
        return results

    ###### INSERT entities AND relationships ######
    async def upsert(self, data: dict[str, dict]):
        if not len(data):
            logger.warning("You are inserting empty data into the vector DB")
            return []
        if self.namespace == "chunks":
            # Chunks are upserted (including vectors) by TiDBKVStorage already.
            return []
        logger.info(f"Inserting {len(data)} vectors to {self.namespace}")

        list_data = [
            {
                "id": k,
                **{k1: v1 for k1, v1 in v.items()},
            }
            for k, v in data.items()
        ]
        contents = [v["content"] for v in data.values()]
        batches = [
            contents[i : i + self._max_batch_size]
            for i in range(0, len(contents), self._max_batch_size)
        ]

        # Index each batch so results can be re-ordered: asyncio.as_completed
        # yields in completion order, which would misalign vectors with list_data.
        async def _embed(i: int, batch: list[str]):
            return i, await self.embedding_func(batch)

        embeddings_list = [None] * len(batches)
        for f in tqdm(
            asyncio.as_completed([_embed(i, b) for i, b in enumerate(batches)]),
            total=len(batches),
            desc="Generating embeddings",
            unit="batch",
        ):
            i, batch_embeddings = await f
            embeddings_list[i] = batch_embeddings
        embeddings = np.concatenate(embeddings_list)
        for i, d in enumerate(list_data):
            d["content_vector"] = embeddings[i]

        if self.namespace == "entities":
            merge_sql = SQL_TEMPLATES["upsert_entity"]
            data = []
            for item in list_data:
                data.append(
                    {
                        "id": item["id"],
                        "name": item["entity_name"],
                        "content": item["content"],
                        "content_vector": f"{item['content_vector'].tolist()}",
                        "workspace": self.db.workspace,
                    }
                )
            await self.db.execute(merge_sql, data)

        elif self.namespace == "relationships":
            merge_sql = SQL_TEMPLATES["upsert_relationship"]
            data = []
            for item in list_data:
                data.append(
                    {
                        "id": item["id"],
                        "source_name": item["src_id"],
                        "target_name": item["tgt_id"],
                        "content": item["content"],
                        "content_vector": f"{item['content_vector'].tolist()}",
                        "workspace": self.db.workspace,
                    }
                )
            await self.db.execute(merge_sql, data)

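# Sketch of a direct vector lookup (an aside, assuming the demo wiring). For
# the `entities` namespace the SQL template below returns rows shaped like
# {"entity_name": ...}:
#
#     hits = await rag.entities_vdb.query("Scrooge", top_k=5)
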
# namespace -> table name
N_T = {
    "full_docs": "LIGHTRAG_DOC_FULL",
    "text_chunks": "LIGHTRAG_DOC_CHUNKS",
    "chunks": "LIGHTRAG_DOC_CHUNKS",
    "entities": "LIGHTRAG_GRAPH_NODES",
    "relationships": "LIGHTRAG_GRAPH_EDGES",
}
# namespace -> business-key column
N_ID = {
    "full_docs": "doc_id",
    "text_chunks": "chunk_id",
    "chunks": "chunk_id",
    "entities": "entity_id",
    "relationships": "relation_id",
}

TABLES = {
    "LIGHTRAG_DOC_FULL": {
        "ddl": """
        CREATE TABLE LIGHTRAG_DOC_FULL (
            `id` BIGINT PRIMARY KEY AUTO_RANDOM,
            `doc_id` VARCHAR(256) NOT NULL,
            `workspace` VARCHAR(1024),
            `content` LONGTEXT,
            `meta` JSON,
            `createtime` TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            `updatetime` TIMESTAMP DEFAULT NULL,
            UNIQUE KEY (`doc_id`)
        );
        """
    },
    "LIGHTRAG_DOC_CHUNKS": {
        "ddl": """
        CREATE TABLE LIGHTRAG_DOC_CHUNKS (
            `id` BIGINT PRIMARY KEY AUTO_RANDOM,
            `chunk_id` VARCHAR(256) NOT NULL,
            `full_doc_id` VARCHAR(256) NOT NULL,
            `workspace` VARCHAR(1024),
            `chunk_order_index` INT,
            `tokens` INT,
            `content` LONGTEXT,
            `content_vector` VECTOR,
            `createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
            `updatetime` DATETIME DEFAULT NULL,
            UNIQUE KEY (`chunk_id`)
        );
        """
    },
    "LIGHTRAG_GRAPH_NODES": {
        "ddl": """
        CREATE TABLE LIGHTRAG_GRAPH_NODES (
            `id` BIGINT PRIMARY KEY AUTO_RANDOM,
            `entity_id` VARCHAR(256) NOT NULL,
            `workspace` VARCHAR(1024),
            `name` VARCHAR(2048),
            `content` LONGTEXT,
            `content_vector` VECTOR,
            `createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
            `updatetime` DATETIME DEFAULT NULL,
            UNIQUE KEY (`entity_id`)
        );
        """
    },
    "LIGHTRAG_GRAPH_EDGES": {
        "ddl": """
        CREATE TABLE LIGHTRAG_GRAPH_EDGES (
            `id` BIGINT PRIMARY KEY AUTO_RANDOM,
            `relation_id` VARCHAR(256) NOT NULL,
            `workspace` VARCHAR(1024),
            `source_name` VARCHAR(2048),
            `target_name` VARCHAR(2048),
            `content` LONGTEXT,
            `content_vector` VECTOR,
            `createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
            `updatetime` DATETIME DEFAULT NULL,
            UNIQUE KEY (`relation_id`)
        );
        """
    },
    "LIGHTRAG_LLM_CACHE": {
        "ddl": """
        CREATE TABLE LIGHTRAG_LLM_CACHE (
            `id` BIGINT PRIMARY KEY AUTO_INCREMENT,
            `send` TEXT,
            `return` TEXT,  -- `return` is a reserved word in TiDB/MySQL, so it must be quoted
            `model` VARCHAR(1024),
            `createtime` DATETIME DEFAULT CURRENT_TIMESTAMP,
            `updatetime` DATETIME DEFAULT NULL
        );
        """
    },
}

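# The DDLs above rely on TiDB's VECTOR column type, and the templates below on
# VEC_COSINE_DISTANCE; a quick sanity check (an aside, assuming a TiDB
# deployment with vector search enabled, e.g. TiDB Cloud Serverless):
#
#     res = await tidb.query("SELECT VEC_COSINE_DISTANCE('[1,0,0]', '[0,1,0]') AS d")
#     # res["d"] == 1.0  (cosine distance: 0 is identical, 2 is opposite)
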
SQL_TEMPLATES = {
    # SQL for KVStorage
    "get_by_id_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM LIGHTRAG_DOC_FULL WHERE doc_id = :id AND workspace = :workspace",
    "get_by_id_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE chunk_id = :id AND workspace = :workspace",
    "get_by_ids_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM LIGHTRAG_DOC_FULL WHERE doc_id IN ({ids}) AND workspace = :workspace",
    "get_by_ids_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE chunk_id IN ({ids}) AND workspace = :workspace",
    "filter_keys": "SELECT {id_field} AS id FROM {table_name} WHERE {id_field} IN ({ids}) AND workspace = :workspace",
    # SQL for merge operations (TiDB version with INSERT ... ON DUPLICATE KEY UPDATE)
    "upsert_doc_full": """
        INSERT INTO LIGHTRAG_DOC_FULL (doc_id, content, workspace)
        VALUES (:id, :content, :workspace)
        ON DUPLICATE KEY UPDATE content = VALUES(content), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
        """,
    "upsert_chunk": """
        INSERT INTO LIGHTRAG_DOC_CHUNKS(chunk_id, content, tokens, chunk_order_index, full_doc_id, content_vector, workspace)
        VALUES (:id, :content, :tokens, :chunk_order_index, :full_doc_id, :content_vector, :workspace)
        ON DUPLICATE KEY UPDATE
        content = VALUES(content), tokens = VALUES(tokens), chunk_order_index = VALUES(chunk_order_index),
        full_doc_id = VALUES(full_doc_id), content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
        """,
    # SQL for VectorStorage
    "entities": """SELECT n.name as entity_name FROM
        (SELECT entity_id as id, name, VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance
        FROM LIGHTRAG_GRAPH_NODES WHERE workspace = :workspace) n
        WHERE n.distance > :better_than_threshold ORDER BY n.distance DESC LIMIT :top_k""",
    "relationships": """SELECT e.source_name as src_id, e.target_name as tgt_id FROM
        (SELECT source_name, target_name, VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance
        FROM LIGHTRAG_GRAPH_EDGES WHERE workspace = :workspace) e
        WHERE e.distance > :better_than_threshold ORDER BY e.distance DESC LIMIT :top_k""",
    "chunks": """SELECT c.id FROM
        (SELECT chunk_id as id, VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance
        FROM LIGHTRAG_DOC_CHUNKS WHERE workspace = :workspace) c
        WHERE c.distance > :better_than_threshold ORDER BY c.distance DESC LIMIT :top_k""",
    "upsert_entity": """
        INSERT INTO LIGHTRAG_GRAPH_NODES(entity_id, name, content, content_vector, workspace)
        VALUES(:id, :name, :content, :content_vector, :workspace)
        ON DUPLICATE KEY UPDATE
        name = VALUES(name), content = VALUES(content), content_vector = VALUES(content_vector),
        workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
        """,
    "upsert_relationship": """
        INSERT INTO LIGHTRAG_GRAPH_EDGES(relation_id, source_name, target_name, content, content_vector, workspace)
        VALUES(:id, :source_name, :target_name, :content, :content_vector, :workspace)
        ON DUPLICATE KEY UPDATE
        source_name = VALUES(source_name), target_name = VALUES(target_name), content = VALUES(content),
        content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
        """,
}
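The templates mix two substitution mechanisms: `:name` placeholders are bound by SQLAlchemy at execution time, while `{ids}`, `{table_name}` and `{id_field}` are filled with Python's `str.format` before binding. A sketch of how one template ends up being executed (names from the file above; illustrative only):

sql = SQL_TEMPLATES["get_by_ids_full_docs"].format(ids="'doc-1','doc-2'")
rows = await tidb.query(sql, multirows=True)  # :workspace is bound by query() itself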
lightrag/lightrag.py
CHANGED
@@ -77,6 +77,8 @@ OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBStorage")
 MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge")
 MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage")
 ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage")
+TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage")
+TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage")


 def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
@@ -260,11 +262,13 @@ class LightRAG:
             "JsonKVStorage": JsonKVStorage,
             "OracleKVStorage": OracleKVStorage,
             "MongoKVStorage": MongoKVStorage,
+            "TiDBKVStorage": TiDBKVStorage,
             # vector storage
             "NanoVectorDBStorage": NanoVectorDBStorage,
             "OracleVectorDBStorage": OracleVectorDBStorage,
             "MilvusVectorDBStorge": MilvusVectorDBStorge,
             "ChromaVectorDBStorage": ChromaVectorDBStorage,
+            "TiDBVectorDBStorage": TiDBVectorDBStorage,
             # graph storage
             "NetworkXStorage": NetworkXStorage,
             "Neo4JStorage": Neo4JStorage,
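The two new classes are registered through the same `lazy_external_import` indirection as the other optional backends, so `pymysql` and `sqlalchemy` are only imported once a TiDB storage is actually selected. A simplified sketch of that pattern (the real helper lives in lightrag.py; this is illustrative, not its exact implementation):

import importlib

def lazy_external_import(module_name: str, class_name: str):
    # Return a factory that defers the import until first instantiation (sketch).
    def factory(*args, **kwargs):
        module = importlib.import_module(module_name, package="lightrag")
        return getattr(module, class_name)(*args, **kwargs)
    return factory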
requirements.txt
CHANGED
@@ -13,9 +13,12 @@ openai
 oracledb
 pymilvus
 pymongo
+pymysql
 pyvis
-tenacity
 # lmdeploy[all]
+sqlalchemy
+tenacity
+

 # LLM packages
 tiktoken