Weaxs committed
Commit 2c6a893 · 1 Parent(s): c87e4c5

pre-commit fix tidb

lightrag/kg/tidb_impl.py CHANGED
@@ -19,8 +19,10 @@ class TiDB(object):
         self.password = config.get("password", None)
         self.database = config.get("database", None)
         self.workspace = config.get("workspace", None)
-        connection_string = (f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
-                             f"?ssl_verify_cert=true&ssl_verify_identity=true")
+        connection_string = (
+            f"mysql+pymysql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
+            f"?ssl_verify_cert=true&ssl_verify_identity=true"
+        )
 
         try:
             self.engine = create_engine(connection_string)
@@ -49,7 +51,7 @@ class TiDB(object):
         self, sql: str, params: dict = None, multirows: bool = False
     ) -> Union[dict, None]:
         if params is None:
-            params = { "workspace": self.workspace }
+            params = {"workspace": self.workspace}
         else:
             params.update({"workspace": self.workspace})
         with self.engine.connect() as conn, conn.begin():
@@ -130,8 +132,8 @@ class TiDBKVStorage(BaseKVStorage):
         """过滤掉重复内容"""
         SQL = SQL_TEMPLATES["filter_keys"].format(
             table_name=N_T[self.namespace],
-            id_field= N_ID[self.namespace],
-            ids=",".join([f"'{id}'" for id in keys])
+            id_field=N_ID[self.namespace],
+            ids=",".join([f"'{id}'" for id in keys]),
         )
         try:
             await self.db.query(SQL)
@@ -161,7 +163,7 @@ class TiDBKVStorage(BaseKVStorage):
         ]
         contents = [v["content"] for v in data.values()]
         batches = [
-            contents[i: i + self._max_batch_size]
+            contents[i : i + self._max_batch_size]
             for i in range(0, len(contents), self._max_batch_size)
         ]
         embeddings_list = await asyncio.gather(
@@ -174,26 +176,30 @@ class TiDBKVStorage(BaseKVStorage):
             merge_sql = SQL_TEMPLATES["upsert_chunk"]
             data = []
             for item in list_data:
-                data.append({
-                    "id": item["__id__"],
-                    "content": item["content"],
-                    "tokens": item["tokens"],
-                    "chunk_order_index": item["chunk_order_index"],
-                    "full_doc_id": item["full_doc_id"],
-                    "content_vector": f"{item["__vector__"].tolist()}",
-                    "workspace": self.db.workspace,
-                })
+                data.append(
+                    {
+                        "id": item["__id__"],
+                        "content": item["content"],
+                        "tokens": item["tokens"],
+                        "chunk_order_index": item["chunk_order_index"],
+                        "full_doc_id": item["full_doc_id"],
+                        "content_vector": f"{item["__vector__"].tolist()}",
+                        "workspace": self.db.workspace,
+                    }
+                )
             await self.db.execute(merge_sql, data)
 
         if self.namespace == "full_docs":
             merge_sql = SQL_TEMPLATES["upsert_doc_full"]
             data = []
             for k, v in self._data.items():
-                data.append({
-                    "id": k,
-                    "content": v["content"],
-                    "workspace": self.db.workspace,
-                })
+                data.append(
+                    {
+                        "id": k,
+                        "content": v["content"],
+                        "workspace": self.db.workspace,
+                    }
+                )
             await self.db.execute(merge_sql, data)
         return left_data
 
@@ -201,6 +207,7 @@ class TiDBKVStorage(BaseKVStorage):
         if self.namespace in ["full_docs", "text_chunks"]:
             logger.info("full doc and chunk data had been saved into TiDB db!")
 
+
 @dataclass
 class TiDBVectorDBStorage(BaseVectorStorage):
     cosine_better_than_threshold: float = 0.2
@@ -215,7 +222,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
         )
 
     async def query(self, query: str, top_k: int) -> list[dict]:
-        """ search from tidb vector"""
+        """search from tidb vector"""
 
         embeddings = await self.embedding_func([query])
         embedding = embeddings[0]
@@ -228,8 +235,10 @@ class TiDBVectorDBStorage(BaseVectorStorage):
             "better_than_threshold": self.cosine_better_than_threshold,
         }
 
-        results = await self.db.query(SQL_TEMPLATES[self.namespace], params=params, multirows=True)
-        print("vector search result:",results)
+        results = await self.db.query(
+            SQL_TEMPLATES[self.namespace], params=params, multirows=True
+        )
+        print("vector search result:", results)
         if not results:
             return []
         return results
@@ -253,16 +262,16 @@ class TiDBVectorDBStorage(BaseVectorStorage):
         ]
         contents = [v["content"] for v in data.values()]
         batches = [
-            contents[i: i + self._max_batch_size]
+            contents[i : i + self._max_batch_size]
             for i in range(0, len(contents), self._max_batch_size)
        ]
         embedding_tasks = [self.embedding_func(batch) for batch in batches]
         embeddings_list = []
         for f in tqdm(
-            asyncio.as_completed(embedding_tasks),
-            total=len(embedding_tasks),
-            desc="Generating embeddings",
-            unit="batch",
+            asyncio.as_completed(embedding_tasks),
+            total=len(embedding_tasks),
+            desc="Generating embeddings",
+            unit="batch",
         ):
             embeddings = await f
             embeddings_list.append(embeddings)
@@ -274,27 +283,31 @@ class TiDBVectorDBStorage(BaseVectorStorage):
             data = []
             for item in list_data:
                 merge_sql = SQL_TEMPLATES["upsert_entity"]
-                data.append({
-                    "id": item["id"],
-                    "name": item["entity_name"],
-                    "content": item["content"],
-                    "content_vector": f"{item["content_vector"].tolist()}",
-                    "workspace": self.db.workspace,
-                })
+                data.append(
+                    {
+                        "id": item["id"],
+                        "name": item["entity_name"],
+                        "content": item["content"],
+                        "content_vector": f"{item["content_vector"].tolist()}",
+                        "workspace": self.db.workspace,
+                    }
+                )
             await self.db.execute(merge_sql, data)
 
         elif self.namespace == "relationships":
             data = []
             for item in list_data:
                 merge_sql = SQL_TEMPLATES["upsert_relationship"]
-                data.append({
-                    "id": item["id"],
-                    "source_name": item["src_id"],
-                    "target_name": item["tgt_id"],
-                    "content": item["content"],
-                    "content_vector": f"{item["content_vector"].tolist()}",
-                    "workspace": self.db.workspace,
-                })
+                data.append(
+                    {
+                        "id": item["id"],
+                        "source_name": item["src_id"],
+                        "target_name": item["tgt_id"],
+                        "content": item["content"],
+                        "content_vector": f"{item["content_vector"].tolist()}",
+                        "workspace": self.db.workspace,
+                    }
+                )
             await self.db.execute(merge_sql, data)
 
 
@@ -346,8 +359,7 @@ TABLES = {
     """
     },
    "LIGHTRAG_GRAPH_NODES": {
-        "ddl":
-        """
+        "ddl": """
     CREATE TABLE LIGHTRAG_GRAPH_NODES (
         `id` BIGINT PRIMARY KEY AUTO_RANDOM,
         `entity_id` VARCHAR(256) NOT NULL,
@@ -362,8 +374,7 @@ TABLES = {
     """
     },
    "LIGHTRAG_GRAPH_EDGES": {
-        "ddl":
-        """
+        "ddl": """
     CREATE TABLE LIGHTRAG_GRAPH_EDGES (
         `id` BIGINT PRIMARY KEY AUTO_RANDOM,
         `relation_id` VARCHAR(256) NOT NULL,
@@ -400,7 +411,6 @@ SQL_TEMPLATES = {
     "get_by_ids_full_docs": "SELECT doc_id as id, IFNULL(content, '') AS content FROM LIGHTRAG_DOC_FULL WHERE doc_id IN ({ids}) AND workspace = :workspace",
     "get_by_ids_text_chunks": "SELECT chunk_id as id, tokens, IFNULL(content, '') AS content, chunk_order_index, full_doc_id FROM LIGHTRAG_DOC_CHUNKS WHERE chunk_id IN ({ids}) AND workspace = :workspace",
     "filter_keys": "SELECT {id_field} AS id FROM {table_name} WHERE {id_field} IN ({ids}) AND workspace = :workspace",
-
     # SQL for Merge operations (TiDB version with INSERT ... ON DUPLICATE KEY UPDATE)
     "upsert_doc_full": """
     INSERT INTO LIGHTRAG_DOC_FULL (doc_id, content, workspace)
@@ -408,13 +418,12 @@ SQL_TEMPLATES = {
     ON DUPLICATE KEY UPDATE content = VALUES(content), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
     """,
     "upsert_chunk": """
-    INSERT INTO LIGHTRAG_DOC_CHUNKS(chunk_id, content, tokens, chunk_order_index, full_doc_id, content_vector, workspace)
+    INSERT INTO LIGHTRAG_DOC_CHUNKS(chunk_id, content, tokens, chunk_order_index, full_doc_id, content_vector, workspace)
     VALUES (:id, :content, :tokens, :chunk_order_index, :full_doc_id, :content_vector, :workspace)
-    ON DUPLICATE KEY UPDATE
-    content = VALUES(content), tokens = VALUES(tokens), chunk_order_index = VALUES(chunk_order_index),
+    ON DUPLICATE KEY UPDATE
+    content = VALUES(content), tokens = VALUES(tokens), chunk_order_index = VALUES(chunk_order_index),
     full_doc_id = VALUES(full_doc_id), content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
     """,
-
     # SQL for VectorStorage
     "entities": """SELECT n.name as entity_name FROM
     (SELECT entity_id as id, name, VEC_COSINE_DISTANCE(content_vector,:embedding_string) as distance
@@ -428,19 +437,18 @@ SQL_TEMPLATES = {
     (SELECT chunk_id as id,VEC_COSINE_DISTANCE(content_vector, :embedding_string) as distance
     FROM LIGHTRAG_DOC_CHUNKS WHERE workspace = :workspace) c
     WHERE c.distance>:better_than_threshold ORDER BY c.distance DESC LIMIT :top_k""",
-
     "upsert_entity": """
-    INSERT INTO LIGHTRAG_GRAPH_NODES(entity_id, name, content, content_vector, workspace)
-    VALUES(:id, :name, :content, :content_vector, :workspace)
-    ON DUPLICATE KEY UPDATE
-    name = VALUES(name), content = VALUES(content), content_vector = VALUES(content_vector),
+    INSERT INTO LIGHTRAG_GRAPH_NODES(entity_id, name, content, content_vector, workspace)
+    VALUES(:id, :name, :content, :content_vector, :workspace)
+    ON DUPLICATE KEY UPDATE
+    name = VALUES(name), content = VALUES(content), content_vector = VALUES(content_vector),
     workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
     """,
     "upsert_relationship": """
-    INSERT INTO LIGHTRAG_GRAPH_EDGES(relation_id, source_name, target_name, content, content_vector, workspace)
+    INSERT INTO LIGHTRAG_GRAPH_EDGES(relation_id, source_name, target_name, content, content_vector, workspace)
     VALUES(:id, :source_name, :target_name, :content, :content_vector, :workspace)
-    ON DUPLICATE KEY UPDATE
-    source_name = VALUES(source_name), target_name = VALUES(target_name), content = VALUES(content),
+    ON DUPLICATE KEY UPDATE
+    source_name = VALUES(source_name), target_name = VALUES(target_name), content = VALUES(content),
     content_vector = VALUES(content_vector), workspace = VALUES(workspace), updatetime = CURRENT_TIMESTAMP
-    """
-}
+    """,
+}
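The hunks above are formatting-only, but they touch the backend's two core patterns: building the pymysql/SQLAlchemy connection string and scoping every query to a workspace. A minimal, self-contained sketch of that pattern follows; the host, credentials, and the query helper are illustrative placeholders, not code from this commit.

# Illustrative only: placeholder host/credentials, mirroring how
# TiDB.__init__ builds the DSN and TiDB.query scopes parameters.
from sqlalchemy import create_engine, text

config = {
    "host": "127.0.0.1",   # hypothetical TiDB endpoint
    "port": 4000,
    "user": "root",
    "password": "secret",
    "database": "lightrag",
    "workspace": "default",
}

connection_string = (
    f"mysql+pymysql://{config['user']}:{config['password']}"
    f"@{config['host']}:{config['port']}/{config['database']}"
    f"?ssl_verify_cert=true&ssl_verify_identity=true"
)
engine = create_engine(connection_string)


def query(sql: str, params: dict = None, multirows: bool = False):
    # Every statement is scoped to the configured workspace, as in TiDB.query().
    if params is None:
        params = {"workspace": config["workspace"]}
    else:
        params.update({"workspace": config["workspace"]})
    with engine.connect() as conn, conn.begin():
        rows = conn.execute(text(sql), params).fetchall()
    return rows if multirows else (rows[0] if rows else None)


# Example usage: fetch one chunk by id within the workspace.
row = query(
    "SELECT chunk_id, content FROM LIGHTRAG_DOC_CHUNKS "
    "WHERE chunk_id = :id AND workspace = :workspace",
    params={"id": "chunk-1"},
)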
 
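The remaining hunks reformat one recurring flow: slice pending contents into _max_batch_size batches, embed the batches concurrently, then pass the upsert template a list of per-row parameter dicts so the INSERT ... ON DUPLICATE KEY UPDATE statement runs once per row. A condensed sketch of that flow, assuming a stand-in embedding function and a trimmed-down column list rather than the project's actual API:

import asyncio

import numpy as np
from sqlalchemy import text

MAX_BATCH_SIZE = 32


async def embedding_func(batch: list[str]) -> np.ndarray:
    # Stand-in: a real implementation would call an embedding model.
    return np.zeros((len(batch), 8))


async def embed_and_upsert(engine, chunks: dict[str, dict], workspace: str) -> None:
    contents = [v["content"] for v in chunks.values()]
    # Fixed-size batches over the content list, as in the diff.
    batches = [
        contents[i : i + MAX_BATCH_SIZE]
        for i in range(0, len(contents), MAX_BATCH_SIZE)
    ]
    embeddings_list = await asyncio.gather(*[embedding_func(b) for b in batches])
    embeddings = np.concatenate(embeddings_list)

    # One parameter dict per row; passing the list to execute() runs the
    # upsert statement for every row (executemany-style).
    upsert_sql = text(
        "INSERT INTO LIGHTRAG_DOC_CHUNKS(chunk_id, content, content_vector, workspace) "
        "VALUES (:id, :content, :content_vector, :workspace) "
        "ON DUPLICATE KEY UPDATE content = VALUES(content), "
        "content_vector = VALUES(content_vector), updatetime = CURRENT_TIMESTAMP"
    )
    rows = [
        {
            "id": chunk_id,
            "content": item["content"],
            "content_vector": f"{embeddings[i].tolist()}",
            "workspace": workspace,
        }
        for i, (chunk_id, item) in enumerate(chunks.items())
    ]
    with engine.connect() as conn, conn.begin():
        conn.execute(upsert_sql, rows)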
lightrag/lightrag.py CHANGED
@@ -80,6 +80,7 @@ ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBS
 TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage")
 TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage")
 
+
 def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     """
     Ensure that there is always an event loop available.
requirements.txt CHANGED
@@ -13,11 +13,11 @@ openai
 oracledb
 pymilvus
 pymongo
+pymysql
 pyvis
-tenacity
 # lmdeploy[all]
 sqlalchemy
-pymysql
+tenacity
 
 
 # LLM packages