zrguo committed on
Commit ca31e34 · unverified · 2 Parent(s): 2639305 121ced9

Merge pull request #590 from jin38324/main


Enhance Robustness of insert Method with Pipeline Processing and Caching Mechanisms

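For context, below is a condensed sketch of the new pipeline-based insert flow, adapted from the updated examples/lightrag_oracle_demo.py in this diff. It is a minimal sketch only: the OracleDB connection (oracle_db), llm_model_func, and embedding_func are assumed to be set up exactly as in that demo, and the working directory is a placeholder.

import asyncio
from lightrag import LightRAG

async def main():
    rag = LightRAG(
        working_dir="./workdir",               # placeholder path
        kv_storage="OracleKVStorage",
        vector_storage="OracleVectorDBStorage",
        graph_storage="OracleGraphStorage",
        llm_model_func=llm_model_func,         # assumed: defined as in the demo
        embedding_func=embedding_func,         # assumed: an EmbeddingFunc as in the demo
        addon_params={"insert_batch_size": 2},
    )
    # Point every Oracle-backed storage at one shared connection pool
    rag.set_storage_client(db_client=oracle_db)  # assumed: an initialized OracleDB

    texts = ["first document ...", "second document ..."]
    # New pipeline: store de-duplicated docs, chunk pending docs, then extract the graph.
    # Pending/failed items are tracked via a status column, so failed steps can be retried.
    await rag.apipeline_process_documents(texts)
    await rag.apipeline_process_chunks()
    await rag.apipeline_process_extract_graph()

asyncio.run(main())

The old single-call path (await rag.ainsert(texts)) still works, as the demo's commented-out "Old method" line shows.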
.DS_Store DELETED
Binary file (8.2 kB)
 
.gitignore CHANGED
@@ -21,3 +21,4 @@ rag_storage
21
  venv/
22
  examples/input/
23
  examples/output/
 
 
21
  venv/
22
  examples/input/
23
  examples/output/
24
+ .DS_Store
get_all_edges_nx.py → examples/get_all_edges_nx.py RENAMED
File without changes
examples/lightrag_oracle_demo.py CHANGED
@@ -20,7 +20,8 @@ BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/"
20
  APIKEY = "ocigenerativeai"
21
  CHATMODEL = "cohere.command-r-plus"
22
  EMBEDMODEL = "cohere.embed-multilingual-v3.0"
23
-
 
24
 
25
  if not os.path.exists(WORKING_DIR):
26
  os.mkdir(WORKING_DIR)
@@ -86,30 +87,46 @@ async def main():
86
  # We use Oracle DB as the KV/vector/graph storage
87
  # You can add `addon_params={"example_number": 1, "language": "Simplified Chinese"}` to control the prompt
88
  rag = LightRAG(
89
- enable_llm_cache=False,
90
  working_dir=WORKING_DIR,
91
- chunk_token_size=512,
92
  llm_model_func=llm_model_func,
93
  embedding_func=EmbeddingFunc(
94
  embedding_dim=embedding_dimension,
95
- max_token_size=512,
96
  func=embedding_func,
97
  ),
98
  graph_storage="OracleGraphStorage",
99
  kv_storage="OracleKVStorage",
100
  vector_storage="OracleVectorDBStorage",
101
  )
102
 
103
  # Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
104
- rag.graph_storage_cls.db = oracle_db
105
- rag.key_string_value_json_storage_cls.db = oracle_db
106
- rag.vector_db_storage_cls.db = oracle_db
107
- # add embedding_func for graph database, it's deleted in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c
108
- rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func
109
 
110
  # Extract and Insert into LightRAG storage
111
- with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
112
- await rag.ainsert(f.read())
113
 
114
  # Perform search in different modes
115
  modes = ["naive", "local", "global", "hybrid"]
 
20
  APIKEY = "ocigenerativeai"
21
  CHATMODEL = "cohere.command-r-plus"
22
  EMBEDMODEL = "cohere.embed-multilingual-v3.0"
23
+ CHUNK_TOKEN_SIZE = 1024
24
+ MAX_TOKENS = 4000
25
 
26
  if not os.path.exists(WORKING_DIR):
27
  os.mkdir(WORKING_DIR)
 
87
  # We use Oracle DB as the KV/vector/graph storage
88
  # You can add `addon_params={"example_number": 1, "language": "Simplified Chinese"}` to control the prompt
89
  rag = LightRAG(
90
+ # log_level="DEBUG",
91
  working_dir=WORKING_DIR,
92
+ entity_extract_max_gleaning=1,
93
+ enable_llm_cache=True,
94
+ enable_llm_cache_for_entity_extract=True,
95
+ embedding_cache_config=None, # {"enabled": True,"similarity_threshold": 0.90},
96
+ chunk_token_size=CHUNK_TOKEN_SIZE,
97
+ llm_model_max_token_size=MAX_TOKENS,
98
  llm_model_func=llm_model_func,
99
  embedding_func=EmbeddingFunc(
100
  embedding_dim=embedding_dimension,
101
+ max_token_size=500,
102
  func=embedding_func,
103
  ),
104
  graph_storage="OracleGraphStorage",
105
  kv_storage="OracleKVStorage",
106
  vector_storage="OracleVectorDBStorage",
107
+ addon_params={
108
+ "example_number": 1,
109
+ "language": "Simplfied Chinese",
110
+ "entity_types": ["organization", "person", "geo", "event"],
111
+ "insert_batch_size": 2,
112
+ },
113
  )
114
 
115
  # Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
116
+ rag.set_storage_client(db_client=oracle_db)
117
 
118
  # Extract and Insert into LightRAG storage
119
+ with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
120
+ all_text = f.read()
121
+ texts = [x for x in all_text.split("\n") if x]
122
+
123
+ # New mode: use the pipeline
124
+ await rag.apipeline_process_documents(texts)
125
+ await rag.apipeline_process_chunks()
126
+ await rag.apipeline_process_extract_graph()
127
+
128
+ # Old method: use ainsert
129
+ # await rag.ainsert(texts)
130
 
131
  # Perform search in different modes
132
  modes = ["naive", "local", "global", "hybrid"]
test.py → examples/test.py RENAMED
File without changes
test_chromadb.py → examples/test_chromadb.py RENAMED
File without changes
test_neo4j.py → examples/test_neo4j.py RENAMED
File without changes
lightrag/kg/oracle_impl.py CHANGED
@@ -153,8 +153,6 @@ class OracleDB:
153
  if data is None:
154
  await cursor.execute(sql)
155
  else:
156
- # print(data)
157
- # print(sql)
158
  await cursor.execute(sql, data)
159
  await connection.commit()
160
  except Exception as e:
@@ -167,35 +165,64 @@ class OracleDB:
167
  @dataclass
168
  class OracleKVStorage(BaseKVStorage):
169
  # should pass db object to self.db
 
 
 
170
  def __post_init__(self):
171
  self._data = {}
172
- self._max_batch_size = self.global_config["embedding_batch_num"]
173
 
174
  ################ QUERY METHODS ################
175
 
176
  async def get_by_id(self, id: str) -> Union[dict, None]:
177
- """根据 id 获取 doc_full 数据."""
178
  SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
179
  params = {"workspace": self.db.workspace, "id": id}
180
  # print("get_by_id:"+SQL)
181
- res = await self.db.query(SQL, params)
182
  if res:
183
- data = res # {"data":res}
184
- # print (data)
185
- return data
186
  else:
187
  return None
188
 
189
- # Query by id
190
  async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
191
- """根据 id 获取 doc_chunks 数据"""
192
  SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
193
  ids=",".join([f"'{id}'" for id in ids])
194
  )
195
  params = {"workspace": self.db.workspace}
196
  # print("get_by_ids:"+SQL)
197
- # print(params)
198
  res = await self.db.query(SQL, params, multirows=True)
199
  if res:
200
  data = res # [{"data":i} for i in res]
201
  # print(data)
@@ -203,38 +230,43 @@ class OracleKVStorage(BaseKVStorage):
203
  else:
204
  return None
205
206
  async def filter_keys(self, keys: list[str]) -> set[str]:
207
- """过滤掉重复内容"""
208
  SQL = SQL_TEMPLATES["filter_keys"].format(
209
  table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys])
210
  )
211
  params = {"workspace": self.db.workspace}
212
- try:
213
- await self.db.query(SQL, params)
214
- except Exception as e:
215
- logger.error(f"Oracle database error: {e}")
216
- print(SQL)
217
- print(params)
218
  res = await self.db.query(SQL, params, multirows=True)
219
- data = None
220
  if res:
221
  exist_keys = [key["id"] for key in res]
222
  data = set([s for s in keys if s not in exist_keys])
 
223
  else:
224
- exist_keys = []
225
- data = set([s for s in keys if s not in exist_keys])
226
- return data
227
 
228
  ################ INSERT METHODS ################
229
  async def upsert(self, data: dict[str, dict]):
230
- left_data = {k: v for k, v in data.items() if k not in self._data}
231
- self._data.update(left_data)
232
- # print(self._data)
233
- # values = []
234
  if self.namespace == "text_chunks":
235
  list_data = [
236
  {
237
- "__id__": k,
238
  **{k1: v1 for k1, v1 in v.items()},
239
  }
240
  for k, v in data.items()
@@ -250,35 +282,50 @@ class OracleKVStorage(BaseKVStorage):
250
  embeddings = np.concatenate(embeddings_list)
251
  for i, d in enumerate(list_data):
252
  d["__vector__"] = embeddings[i]
253
- # print(list_data)
 
254
  for item in list_data:
255
- merge_sql = SQL_TEMPLATES["merge_chunk"]
256
- data = {
257
- "check_id": item["__id__"],
258
- "id": item["__id__"],
259
  "content": item["content"],
260
  "workspace": self.db.workspace,
261
  "tokens": item["tokens"],
262
  "chunk_order_index": item["chunk_order_index"],
263
  "full_doc_id": item["full_doc_id"],
264
  "content_vector": item["__vector__"],
 
265
  }
266
- # print(merge_sql)
267
- await self.db.execute(merge_sql, data)
268
-
269
  if self.namespace == "full_docs":
270
- for k, v in self._data.items():
271
  # values.clear()
272
  merge_sql = SQL_TEMPLATES["merge_doc_full"]
273
- data = {
274
- "check_id": k,
275
  "id": k,
276
  "content": v["content"],
277
  "workspace": self.db.workspace,
278
  }
279
- # print(merge_sql)
280
- await self.db.execute(merge_sql, data)
281
- return left_data
282
 
283
  async def index_done_callback(self):
284
  if self.namespace in ["full_docs", "text_chunks"]:
@@ -287,6 +334,8 @@ class OracleKVStorage(BaseKVStorage):
287
 
288
  @dataclass
289
  class OracleVectorDBStorage(BaseVectorStorage):
 
 
290
  cosine_better_than_threshold: float = 0.2
291
 
292
  def __post_init__(self):
@@ -328,7 +377,7 @@ class OracleGraphStorage(BaseGraphStorage):
328
 
329
  def __post_init__(self):
330
  """从graphml文件加载图"""
331
- self._max_batch_size = self.global_config["embedding_batch_num"]
332
 
333
  #################### insert method ################
334
 
@@ -362,7 +411,6 @@ class OracleGraphStorage(BaseGraphStorage):
362
  "content": content,
363
  "content_vector": content_vector,
364
  }
365
- # print(merge_sql)
366
  await self.db.execute(merge_sql, data)
367
  # self._graph.add_node(node_id, **node_data)
368
 
@@ -564,20 +612,26 @@ N_T = {
564
  TABLES = {
565
  "LIGHTRAG_DOC_FULL": {
566
  "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (
567
- id varchar(256)PRIMARY KEY,
568
  workspace varchar(1024),
569
  doc_name varchar(1024),
570
  content CLOB,
571
  meta JSON,
572
  createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
573
- updatetime TIMESTAMP DEFAULT NULL
 
574
  )"""
575
  },
576
  "LIGHTRAG_DOC_CHUNKS": {
577
  "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
578
- id varchar(256) PRIMARY KEY,
579
  workspace varchar(1024),
580
  full_doc_id varchar(256),
 
581
  chunk_order_index NUMBER,
582
  tokens NUMBER,
583
  content CLOB,
@@ -619,9 +673,15 @@ TABLES = {
619
  "LIGHTRAG_LLM_CACHE": {
620
  "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE (
621
  id varchar(256) PRIMARY KEY,
622
- send clob,
623
- return clob,
624
- model varchar(1024),
625
  createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
626
  updatetime TIMESTAMP DEFAULT NULL
627
  )"""
@@ -646,23 +706,44 @@ TABLES = {
646
 
647
  SQL_TEMPLATES = {
648
  # SQL for KVStorage
649
- "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
650
- "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id",
651
- "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID in ({ids})",
652
- "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})",
 
 
 
 
 
 
 
 
 
 
653
  "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
654
- "merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a
655
- USING DUAL
656
- ON (a.id = :check_id)
657
- WHEN NOT MATCHED THEN
658
- INSERT(id,content,workspace) values(:id,:content,:workspace)
659
- """,
660
- "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a
661
- USING DUAL
662
- ON (a.id = :check_id)
663
- WHEN NOT MATCHED THEN
664
- INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector)
665
- values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector) """,
666
  # SQL for VectorStorage
667
  "entities": """SELECT name as entity_name FROM
668
  (SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
@@ -714,16 +795,22 @@ SQL_TEMPLATES = {
714
  COLUMNS (a.name as source_name,b.name as target_name))""",
715
  "merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
716
  USING DUAL
717
- ON (a.workspace = :workspace and a.name=:name and a.source_chunk_id=:source_chunk_id)
718
  WHEN NOT MATCHED THEN
719
  INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector)
720
- values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector) """,
 
 
 
721
  "merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a
722
  USING DUAL
723
- ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and a.source_chunk_id=:source_chunk_id)
724
  WHEN NOT MATCHED THEN
725
  INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
726
- values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """,
 
 
 
727
  "get_all_nodes": """WITH t0 AS (
728
  SELECT name AS id, entity_type AS label, entity_type, description,
729
  '["' || replace(source_chunk_id, '<SEP>', '","') || '"]' source_chunk_ids
 
153
  if data is None:
154
  await cursor.execute(sql)
155
  else:
 
 
156
  await cursor.execute(sql, data)
157
  await connection.commit()
158
  except Exception as e:
 
165
  @dataclass
166
  class OracleKVStorage(BaseKVStorage):
167
  # should pass db object to self.db
168
+ db: OracleDB = None
169
+ meta_fields = None
170
+
171
  def __post_init__(self):
172
  self._data = {}
173
+ self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
174
 
175
  ################ QUERY METHODS ################
176
 
177
  async def get_by_id(self, id: str) -> Union[dict, None]:
178
+ """get doc_full data based on id."""
179
  SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
180
  params = {"workspace": self.db.workspace, "id": id}
181
  # print("get_by_id:"+SQL)
182
+ if "llm_response_cache" == self.namespace:
183
+ array_res = await self.db.query(SQL, params, multirows=True)
184
+ res = {}
185
+ for row in array_res:
186
+ res[row["id"]] = row
187
+ else:
188
+ res = await self.db.query(SQL, params)
189
  if res:
190
+ return res
191
+ else:
192
+ return None
193
+
194
+ async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
195
+ """Specifically for llm_response_cache."""
196
+ SQL = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
197
+ params = {"workspace": self.db.workspace, "cache_mode": mode, "id": id}
198
+ if "llm_response_cache" == self.namespace:
199
+ array_res = await self.db.query(SQL, params, multirows=True)
200
+ res = {}
201
+ for row in array_res:
202
+ res[row["id"]] = row
203
+ return res
204
  else:
205
  return None
206
 
 
207
  async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
208
+ """get doc_chunks data based on id"""
209
  SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
210
  ids=",".join([f"'{id}'" for id in ids])
211
  )
212
  params = {"workspace": self.db.workspace}
213
  # print("get_by_ids:"+SQL)
 
214
  res = await self.db.query(SQL, params, multirows=True)
215
+ if "llm_response_cache" == self.namespace:
216
+ modes = set()
217
+ dict_res: dict[str, dict] = {}
218
+ for row in res:
219
+ modes.add(row["mode"])
220
+ for mode in modes:
221
+ if mode not in dict_res:
222
+ dict_res[mode] = {}
223
+ for row in res:
224
+ dict_res[row["mode"]][row["id"]] = row
225
+ res = [{k: v} for k, v in dict_res.items()]
226
  if res:
227
  data = res # [{"data":i} for i in res]
228
  # print(data)
 
230
  else:
231
  return None
232
 
233
+ async def get_by_status_and_ids(
234
+ self, status: str, ids: list[str]
235
+ ) -> Union[list[dict], None]:
236
+ """Specifically for llm_response_cache."""
237
+ if ids is not None:
238
+ SQL = SQL_TEMPLATES["get_by_status_ids_" + self.namespace].format(
239
+ ids=",".join([f"'{id}'" for id in ids])
240
+ )
241
+ else:
242
+ SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
243
+ params = {"workspace": self.db.workspace, "status": status}
244
+ res = await self.db.query(SQL, params, multirows=True)
245
+ if res:
246
+ return res
247
+ else:
248
+ return None
249
+
250
  async def filter_keys(self, keys: list[str]) -> set[str]:
251
+ """Return keys that don't exist in storage"""
252
  SQL = SQL_TEMPLATES["filter_keys"].format(
253
  table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys])
254
  )
255
  params = {"workspace": self.db.workspace}
 
 
 
 
 
 
256
  res = await self.db.query(SQL, params, multirows=True)
 
257
  if res:
258
  exist_keys = [key["id"] for key in res]
259
  data = set([s for s in keys if s not in exist_keys])
260
+ return data
261
  else:
262
+ return set(keys)
 
 
263
 
264
  ################ INSERT METHODS ################
265
  async def upsert(self, data: dict[str, dict]):
266
  if self.namespace == "text_chunks":
267
  list_data = [
268
  {
269
+ "id": k,
270
  **{k1: v1 for k1, v1 in v.items()},
271
  }
272
  for k, v in data.items()
 
282
  embeddings = np.concatenate(embeddings_list)
283
  for i, d in enumerate(list_data):
284
  d["__vector__"] = embeddings[i]
285
+
286
+ merge_sql = SQL_TEMPLATES["merge_chunk"]
287
  for item in list_data:
288
+ _data = {
289
+ "id": item["id"],
 
 
290
  "content": item["content"],
291
  "workspace": self.db.workspace,
292
  "tokens": item["tokens"],
293
  "chunk_order_index": item["chunk_order_index"],
294
  "full_doc_id": item["full_doc_id"],
295
  "content_vector": item["__vector__"],
296
+ "status": item["status"],
297
  }
298
+ await self.db.execute(merge_sql, _data)
 
 
299
  if self.namespace == "full_docs":
300
+ for k, v in data.items():
301
  # values.clear()
302
  merge_sql = SQL_TEMPLATES["merge_doc_full"]
303
+ _data = {
 
304
  "id": k,
305
  "content": v["content"],
306
  "workspace": self.db.workspace,
307
  }
308
+ await self.db.execute(merge_sql, _data)
309
+
310
+ if self.namespace == "llm_response_cache":
311
+ for mode, items in data.items():
312
+ for k, v in items.items():
313
+ upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
314
+ _data = {
315
+ "workspace": self.db.workspace,
316
+ "id": k,
317
+ "original_prompt": v["original_prompt"],
318
+ "return_value": v["return"],
319
+ "cache_mode": mode,
320
+ }
321
+
322
+ await self.db.execute(upsert_sql, _data)
323
+ return None
324
+
325
+ async def change_status(self, id: str, status: str):
326
+ SQL = SQL_TEMPLATES["change_status"].format(table_name=N_T[self.namespace])
327
+ params = {"workspace": self.db.workspace, "id": id, "status": status}
328
+ await self.db.execute(SQL, params)
329
 
330
  async def index_done_callback(self):
331
  if self.namespace in ["full_docs", "text_chunks"]:
 
334
 
335
  @dataclass
336
  class OracleVectorDBStorage(BaseVectorStorage):
337
+ # should pass db object to self.db
338
+ db: OracleDB = None
339
  cosine_better_than_threshold: float = 0.2
340
 
341
  def __post_init__(self):
 
377
 
378
  def __post_init__(self):
379
  """从graphml文件加载图"""
380
+ self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
381
 
382
  #################### insert method ################
383
 
 
411
  "content": content,
412
  "content_vector": content_vector,
413
  }
 
414
  await self.db.execute(merge_sql, data)
415
  # self._graph.add_node(node_id, **node_data)
416
 
 
612
  TABLES = {
613
  "LIGHTRAG_DOC_FULL": {
614
  "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (
615
+ id varchar(256),
616
  workspace varchar(1024),
617
  doc_name varchar(1024),
618
  content CLOB,
619
  meta JSON,
620
+ content_summary varchar(1024),
621
+ content_length NUMBER,
622
+ status varchar(256),
623
+ chunks_count NUMBER,
624
  createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
625
+ updatetime TIMESTAMP DEFAULT NULL,
626
+ error varchar(4096)
627
  )"""
628
  },
629
  "LIGHTRAG_DOC_CHUNKS": {
630
  "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
631
+ id varchar(256),
632
  workspace varchar(1024),
633
  full_doc_id varchar(256),
634
+ status varchar(256),
635
  chunk_order_index NUMBER,
636
  tokens NUMBER,
637
  content CLOB,
 
673
  "LIGHTRAG_LLM_CACHE": {
674
  "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE (
675
  id varchar(256) PRIMARY KEY,
676
+ workspace varchar(1024),
677
+ cache_mode varchar(256),
678
+ model_name varchar(256),
679
+ original_prompt clob,
680
+ return_value clob,
681
+ embedding CLOB,
682
+ embedding_shape NUMBER,
683
+ embedding_min NUMBER,
684
+ embedding_max NUMBER,
685
  createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
686
  updatetime TIMESTAMP DEFAULT NULL
687
  )"""
 
706
 
707
  SQL_TEMPLATES = {
708
  # SQL for KVStorage
709
+ "get_by_id_full_docs": "select ID,content,status from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
710
+ "get_by_id_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id",
711
+ "get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
712
+ FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""",
713
+ "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
714
+ FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND cache_mode=:cache_mode AND id=:id""",
715
+ "get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
716
+ FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id IN ({ids})""",
717
+ "get_by_ids_full_docs": "select t.*,createtime as created_at from LIGHTRAG_DOC_FULL t where workspace=:workspace and ID in ({ids})",
718
+ "get_by_ids_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})",
719
+ "get_by_status_ids_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status and ID in ({ids})",
720
+ "get_by_status_ids_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status ID in ({ids})",
721
+ "get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status",
722
+ "get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status",
723
  "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
724
+ "change_status": "update {table_name} set status=:status,updatetime=SYSDATE where workspace=:workspace and id=:id",
725
+ "merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a
726
+ USING DUAL
727
+ ON (a.id = :id and a.workspace = :workspace)
728
+ WHEN NOT MATCHED THEN
729
+ INSERT(id,content,workspace) values(:id,:content,:workspace)""",
730
+ "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS
731
+ USING DUAL
732
+ ON (id = :id and workspace = :workspace)
733
+ WHEN NOT MATCHED THEN INSERT
734
+ (id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector,status)
735
+ values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector,:status) """,
736
+ "upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a
737
+ USING DUAL
738
+ ON (a.id = :id)
739
+ WHEN NOT MATCHED THEN
740
+ INSERT (workspace,id,original_prompt,return_value,cache_mode)
741
+ VALUES (:workspace,:id,:original_prompt,:return_value,:cache_mode)
742
+ WHEN MATCHED THEN UPDATE
743
+ SET original_prompt = :original_prompt,
744
+ return_value = :return_value,
745
+ cache_mode = :cache_mode,
746
+ updatetime = SYSDATE""",
747
  # SQL for VectorStorage
748
  "entities": """SELECT name as entity_name FROM
749
  (SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
 
795
  COLUMNS (a.name as source_name,b.name as target_name))""",
796
  "merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
797
  USING DUAL
798
+ ON (a.workspace=:workspace and a.name=:name)
799
  WHEN NOT MATCHED THEN
800
  INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector)
801
+ values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector)
802
+ WHEN MATCHED THEN
803
+ UPDATE SET
804
+ entity_type=:entity_type,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""",
805
  "merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a
806
  USING DUAL
807
+ ON (a.workspace=:workspace and a.source_name=:source_name and a.target_name=:target_name)
808
  WHEN NOT MATCHED THEN
809
  INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
810
+ values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector)
811
+ WHEN MATCHED THEN
812
+ UPDATE SET
813
+ weight=:weight,keywords=:keywords,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""",
814
  "get_all_nodes": """WITH t0 AS (
815
  SELECT name AS id, entity_type AS label, entity_type, description,
816
  '["' || replace(source_chunk_id, '<SEP>', '","') || '"]' source_chunk_ids
lightrag/lightrag.py CHANGED
@@ -28,6 +28,7 @@ from .utils import (
28
  convert_response_to_json,
29
  logger,
30
  set_logger,
 
31
  )
32
  from .base import (
33
  BaseGraphStorage,
@@ -38,21 +39,30 @@ from .base import (
38
  DocStatus,
39
  )
40
 
41
- from .storage import (
42
- JsonKVStorage,
43
- NanoVectorDBStorage,
44
- NetworkXStorage,
45
- JsonDocStatusStorage,
46
- )
47
-
48
  from .prompt import GRAPH_FIELD_SEP
49
 
50
-
51
- # future KG integrations
52
-
53
- # from .kg.ArangoDB_impl import (
54
- # GraphStorage as ArangoDBStorage
55
- # )
56
 
57
 
58
  def lazy_external_import(module_name: str, class_name: str):
@@ -68,34 +78,13 @@ def lazy_external_import(module_name: str, class_name: str):
68
  def import_class(*args, **kwargs):
69
  import importlib
70
 
71
- # Import the module using importlib
72
  module = importlib.import_module(module_name, package=package)
73
-
74
- # Get the class from the module and instantiate it
75
  cls = getattr(module, class_name)
76
  return cls(*args, **kwargs)
77
 
78
  return import_class
79
 
80
 
81
- Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage")
82
- OracleKVStorage = lazy_external_import(".kg.oracle_impl", "OracleKVStorage")
83
- OracleGraphStorage = lazy_external_import(".kg.oracle_impl", "OracleGraphStorage")
84
- OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBStorage")
85
- MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge")
86
- MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage")
87
- ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage")
88
- TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage")
89
- TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage")
90
- TiDBGraphStorage = lazy_external_import(".kg.tidb_impl", "TiDBGraphStorage")
91
- PGKVStorage = lazy_external_import(".kg.postgres_impl", "PGKVStorage")
92
- PGVectorStorage = lazy_external_import(".kg.postgres_impl", "PGVectorStorage")
93
- AGEStorage = lazy_external_import(".kg.age_impl", "AGEStorage")
94
- PGGraphStorage = lazy_external_import(".kg.postgres_impl", "PGGraphStorage")
95
- GremlinStorage = lazy_external_import(".kg.gremlin_impl", "GremlinStorage")
96
- PGDocStatusStorage = lazy_external_import(".kg.postgres_impl", "PGDocStatusStorage")
97
-
98
-
99
  def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
100
  """
101
  Ensure that there is always an event loop available.
@@ -199,34 +188,51 @@ class LightRAG:
199
  logger.setLevel(self.log_level)
200
 
201
  logger.info(f"Logger initialized for working directory: {self.working_dir}")
 
 
 
202
 
203
- _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
 
 
204
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
205
 
206
- # @TODO: should move all storage setup here to leverage initial start params attached to self.
 
 
 
207
 
 
208
  self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
209
- self._get_storage_class()[self.kv_storage]
210
  )
211
- self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[
212
  self.vector_storage
213
- ]
214
- self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
215
  self.graph_storage
216
- ]
217
 
218
- if not os.path.exists(self.working_dir):
219
- logger.info(f"Creating working directory {self.working_dir}")
220
- os.makedirs(self.working_dir)
221
 
222
- self.llm_response_cache = self.key_string_value_json_storage_cls(
223
- namespace="llm_response_cache",
224
- global_config=asdict(self),
225
  embedding_func=None,
226
  )
227
 
228
- self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
229
- self.embedding_func
 
230
  )
231
 
232
  ####
@@ -234,17 +240,14 @@ class LightRAG:
234
  ####
235
  self.full_docs = self.key_string_value_json_storage_cls(
236
  namespace="full_docs",
237
- global_config=asdict(self),
238
  embedding_func=self.embedding_func,
239
  )
240
  self.text_chunks = self.key_string_value_json_storage_cls(
241
  namespace="text_chunks",
242
- global_config=asdict(self),
243
  embedding_func=self.embedding_func,
244
  )
245
  self.chunk_entity_relation_graph = self.graph_storage_cls(
246
  namespace="chunk_entity_relation",
247
- global_config=asdict(self),
248
  embedding_func=self.embedding_func,
249
  )
250
  ####
@@ -253,72 +256,69 @@ class LightRAG:
253
 
254
  self.entities_vdb = self.vector_db_storage_cls(
255
  namespace="entities",
256
- global_config=asdict(self),
257
  embedding_func=self.embedding_func,
258
  meta_fields={"entity_name"},
259
  )
260
  self.relationships_vdb = self.vector_db_storage_cls(
261
  namespace="relationships",
262
- global_config=asdict(self),
263
  embedding_func=self.embedding_func,
264
  meta_fields={"src_id", "tgt_id"},
265
  )
266
  self.chunks_vdb = self.vector_db_storage_cls(
267
  namespace="chunks",
268
- global_config=asdict(self),
269
  embedding_func=self.embedding_func,
270
  )
271
272
  self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
273
  partial(
274
  self.llm_model_func,
275
- hashing_kv=self.llm_response_cache
276
- if self.llm_response_cache
277
- and hasattr(self.llm_response_cache, "global_config")
278
- else self.key_string_value_json_storage_cls(
279
- namespace="llm_response_cache",
280
- global_config=asdict(self),
281
- embedding_func=None,
282
- ),
283
  **self.llm_model_kwargs,
284
  )
285
  )
286
 
287
  # Initialize document status storage
288
- self.doc_status_storage_cls = self._get_storage_class()[self.doc_status_storage]
289
  self.doc_status = self.doc_status_storage_cls(
290
  namespace="doc_status",
291
- global_config=asdict(self),
292
  embedding_func=None,
293
  )
294
 
295
- def _get_storage_class(self) -> dict:
296
- return {
297
- # kv storage
298
- "JsonKVStorage": JsonKVStorage,
299
- "OracleKVStorage": OracleKVStorage,
300
- "MongoKVStorage": MongoKVStorage,
301
- "TiDBKVStorage": TiDBKVStorage,
302
- # vector storage
303
- "NanoVectorDBStorage": NanoVectorDBStorage,
304
- "OracleVectorDBStorage": OracleVectorDBStorage,
305
- "MilvusVectorDBStorge": MilvusVectorDBStorge,
306
- "ChromaVectorDBStorage": ChromaVectorDBStorage,
307
- "TiDBVectorDBStorage": TiDBVectorDBStorage,
308
- # graph storage
309
- "NetworkXStorage": NetworkXStorage,
310
- "Neo4JStorage": Neo4JStorage,
311
- "OracleGraphStorage": OracleGraphStorage,
312
- "AGEStorage": AGEStorage,
313
- "PGGraphStorage": PGGraphStorage,
314
- "PGKVStorage": PGKVStorage,
315
- "PGDocStatusStorage": PGDocStatusStorage,
316
- "PGVectorStorage": PGVectorStorage,
317
- "TiDBGraphStorage": TiDBGraphStorage,
318
- "GremlinStorage": GremlinStorage,
319
- # "ArangoDBStorage": ArangoDBStorage
320
- "JsonDocStatusStorage": JsonDocStatusStorage,
321
- }
322
 
323
  def insert(
324
  self, string_or_strings, split_by_character=None, split_by_character_only=False
@@ -540,6 +540,195 @@ class LightRAG:
540
  if update_storage:
541
  await self._insert_done()
542
543
  async def _insert_done(self):
544
  tasks = []
545
  for storage_inst in [
 
28
  convert_response_to_json,
29
  logger,
30
  set_logger,
31
+ statistic_data,
32
  )
33
  from .base import (
34
  BaseGraphStorage,
 
39
  DocStatus,
40
  )
41
42
  from .prompt import GRAPH_FIELD_SEP
43
 
44
+ STORAGES = {
45
+ "JsonKVStorage": ".storage",
46
+ "NanoVectorDBStorage": ".storage",
47
+ "NetworkXStorage": ".storage",
48
+ "JsonDocStatusStorage": ".storage",
49
+ "Neo4JStorage": ".kg.neo4j_impl",
50
+ "OracleKVStorage": ".kg.oracle_impl",
51
+ "OracleGraphStorage": ".kg.oracle_impl",
52
+ "OracleVectorDBStorage": ".kg.oracle_impl",
53
+ "MilvusVectorDBStorge": ".kg.milvus_impl",
54
+ "MongoKVStorage": ".kg.mongo_impl",
55
+ "ChromaVectorDBStorage": ".kg.chroma_impl",
56
+ "TiDBKVStorage": ".kg.tidb_impl",
57
+ "TiDBVectorDBStorage": ".kg.tidb_impl",
58
+ "TiDBGraphStorage": ".kg.tidb_impl",
59
+ "PGKVStorage": ".kg.postgres_impl",
60
+ "PGVectorStorage": ".kg.postgres_impl",
61
+ "AGEStorage": ".kg.age_impl",
62
+ "PGGraphStorage": ".kg.postgres_impl",
63
+ "GremlinStorage": ".kg.gremlin_impl",
64
+ "PGDocStatusStorage": ".kg.postgres_impl",
65
+ }
66
 
67
 
68
  def lazy_external_import(module_name: str, class_name: str):
 
78
  def import_class(*args, **kwargs):
79
  import importlib
80
 
 
81
  module = importlib.import_module(module_name, package=package)
 
 
82
  cls = getattr(module, class_name)
83
  return cls(*args, **kwargs)
84
 
85
  return import_class
86
 
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
89
  """
90
  Ensure that there is always an event loop available.
 
188
  logger.setLevel(self.log_level)
189
 
190
  logger.info(f"Logger initialized for working directory: {self.working_dir}")
191
+ if not os.path.exists(self.working_dir):
192
+ logger.info(f"Creating working directory {self.working_dir}")
193
+ os.makedirs(self.working_dir)
194
 
195
+ # show config
196
+ global_config = asdict(self)
197
+ _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
198
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
199
 
200
+ # Init LLM
201
+ self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
202
+ self.embedding_func
203
+ )
204
 
205
+ # Initialize all storages
206
  self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
207
+ self._get_storage_class(self.kv_storage)
208
  )
209
+ self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(
210
  self.vector_storage
211
+ )
212
+ self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(
213
  self.graph_storage
214
+ )
215
 
216
+ self.key_string_value_json_storage_cls = partial(
217
+ self.key_string_value_json_storage_cls, global_config=global_config
218
+ )
219
 
220
+ self.vector_db_storage_cls = partial(
221
+ self.vector_db_storage_cls, global_config=global_config
222
+ )
223
+
224
+ self.graph_storage_cls = partial(
225
+ self.graph_storage_cls, global_config=global_config
226
+ )
227
+
228
+ self.json_doc_status_storage = self.key_string_value_json_storage_cls(
229
+ namespace="json_doc_status_storage",
230
  embedding_func=None,
231
  )
232
 
233
+ self.llm_response_cache = self.key_string_value_json_storage_cls(
234
+ namespace="llm_response_cache",
235
+ embedding_func=None,
236
  )
237
 
238
  ####
 
240
  ####
241
  self.full_docs = self.key_string_value_json_storage_cls(
242
  namespace="full_docs",
 
243
  embedding_func=self.embedding_func,
244
  )
245
  self.text_chunks = self.key_string_value_json_storage_cls(
246
  namespace="text_chunks",
 
247
  embedding_func=self.embedding_func,
248
  )
249
  self.chunk_entity_relation_graph = self.graph_storage_cls(
250
  namespace="chunk_entity_relation",
 
251
  embedding_func=self.embedding_func,
252
  )
253
  ####
 
256
 
257
  self.entities_vdb = self.vector_db_storage_cls(
258
  namespace="entities",
 
259
  embedding_func=self.embedding_func,
260
  meta_fields={"entity_name"},
261
  )
262
  self.relationships_vdb = self.vector_db_storage_cls(
263
  namespace="relationships",
 
264
  embedding_func=self.embedding_func,
265
  meta_fields={"src_id", "tgt_id"},
266
  )
267
  self.chunks_vdb = self.vector_db_storage_cls(
268
  namespace="chunks",
 
269
  embedding_func=self.embedding_func,
270
  )
271
 
272
+ if self.llm_response_cache and hasattr(
273
+ self.llm_response_cache, "global_config"
274
+ ):
275
+ hashing_kv = self.llm_response_cache
276
+ else:
277
+ hashing_kv = self.key_string_value_json_storage_cls(
278
+ namespace="llm_response_cache",
279
+ embedding_func=None,
280
+ )
281
+
282
  self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
283
  partial(
284
  self.llm_model_func,
285
+ hashing_kv=hashing_kv,
286
  **self.llm_model_kwargs,
287
  )
288
  )
289
 
290
  # Initialize document status storage
291
+ self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
292
  self.doc_status = self.doc_status_storage_cls(
293
  namespace="doc_status",
294
+ global_config=global_config,
295
  embedding_func=None,
296
  )
297
 
298
+ def _get_storage_class(self, storage_name: str) -> dict:
299
+ import_path = STORAGES[storage_name]
300
+ storage_class = lazy_external_import(import_path, storage_name)
301
+ return storage_class
302
+
303
+ def set_storage_client(self, db_client):
304
+ # Now only tested on Oracle Database
305
+ for storage in [
306
+ self.vector_db_storage_cls,
307
+ self.graph_storage_cls,
308
+ self.doc_status,
309
+ self.full_docs,
310
+ self.text_chunks,
311
+ self.llm_response_cache,
312
+ self.key_string_value_json_storage_cls,
313
+ self.chunks_vdb,
314
+ self.relationships_vdb,
315
+ self.entities_vdb,
316
+ self.graph_storage_cls,
317
+ self.chunk_entity_relation_graph,
318
+ self.llm_response_cache,
319
+ ]:
320
+ # set client
321
+ storage.db = db_client
 
 
 
322
 
323
  def insert(
324
  self, string_or_strings, split_by_character=None, split_by_character_only=False
 
540
  if update_storage:
541
  await self._insert_done()
542
 
543
+ async def apipeline_process_documents(self, string_or_strings):
544
+ """Input list remove duplicates, generate document IDs and initial pendding status, filter out already stored documents, store docs
545
+ Args:
546
+ string_or_strings: Single document string or list of document strings
547
+ """
548
+ if isinstance(string_or_strings, str):
549
+ string_or_strings = [string_or_strings]
550
+
551
+ # 1. Remove duplicate contents from the list
552
+ unique_contents = list(set(doc.strip() for doc in string_or_strings))
553
+
554
+ logger.info(
555
+ f"Received {len(string_or_strings)} docs, contains {len(unique_contents)} new unique documents"
556
+ )
557
+
558
+ # 2. Generate document IDs and initial status
559
+ new_docs = {
560
+ compute_mdhash_id(content, prefix="doc-"): {
561
+ "content": content,
562
+ "content_summary": self._get_content_summary(content),
563
+ "content_length": len(content),
564
+ "status": DocStatus.PENDING,
565
+ "created_at": datetime.now().isoformat(),
566
+ "updated_at": None,
567
+ }
568
+ for content in unique_contents
569
+ }
570
+
571
+ # 3. Filter out already processed documents
572
+ _not_stored_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
573
+ if len(_not_stored_doc_keys) < len(new_docs):
574
+ logger.info(
575
+ f"Skipping {len(new_docs)-len(_not_stored_doc_keys)} already existing documents"
576
+ )
577
+ new_docs = {k: v for k, v in new_docs.items() if k in _not_stored_doc_keys}
578
+
579
+ if not new_docs:
580
+ logger.info("All documents have been processed or are duplicates")
581
+ return None
582
+
583
+ # 4. Store original document
584
+ for doc_id, doc in new_docs.items():
585
+ await self.full_docs.upsert({doc_id: {"content": doc["content"]}})
586
+ await self.full_docs.change_status(doc_id, DocStatus.PENDING)
587
+ logger.info(f"Stored {len(new_docs)} new unique documents")
588
+
589
+ async def apipeline_process_chunks(self):
590
+ """Get pendding documents, split into chunks,insert chunks"""
591
+ # 1. get all pending and failed documents
592
+ _todo_doc_keys = []
593
+ _failed_doc = await self.full_docs.get_by_status_and_ids(
594
+ status=DocStatus.FAILED, ids=None
595
+ )
596
+ _pendding_doc = await self.full_docs.get_by_status_and_ids(
597
+ status=DocStatus.PENDING, ids=None
598
+ )
599
+ if _failed_doc:
600
+ _todo_doc_keys.extend([doc["id"] for doc in _failed_doc])
601
+ if _pendding_doc:
602
+ _todo_doc_keys.extend([doc["id"] for doc in _pendding_doc])
603
+ if not _todo_doc_keys:
604
+ logger.info("All documents have been processed or are duplicates")
605
+ return None
606
+ else:
607
+ logger.info(f"Filtered out {len(_todo_doc_keys)} not processed documents")
608
+
609
+ new_docs = {
610
+ doc["id"]: doc for doc in await self.full_docs.get_by_ids(_todo_doc_keys)
611
+ }
612
+
613
+ # 2. split docs into chunks, insert chunks, update doc status
614
+ chunk_cnt = 0
615
+ batch_size = self.addon_params.get("insert_batch_size", 10)
616
+ for i in range(0, len(new_docs), batch_size):
617
+ batch_docs = dict(list(new_docs.items())[i : i + batch_size])
618
+ for doc_id, doc in tqdm_async(
619
+ batch_docs.items(),
620
+ desc=f"Level 1 - Spliting doc in batch {i//batch_size + 1}",
621
+ ):
622
+ try:
623
+ # Generate chunks from document
624
+ chunks = {
625
+ compute_mdhash_id(dp["content"], prefix="chunk-"): {
626
+ **dp,
627
+ "full_doc_id": doc_id,
628
+ "status": DocStatus.PENDING,
629
+ }
630
+ for dp in chunking_by_token_size(
631
+ doc["content"],
632
+ overlap_token_size=self.chunk_overlap_token_size,
633
+ max_token_size=self.chunk_token_size,
634
+ tiktoken_model=self.tiktoken_model_name,
635
+ )
636
+ }
637
+ chunk_cnt += len(chunks)
638
+ await self.text_chunks.upsert(chunks)
639
+ await self.text_chunks.change_status(doc_id, DocStatus.PROCESSED)
640
+
641
+ try:
642
+ # Store chunks in vector database
643
+ await self.chunks_vdb.upsert(chunks)
644
+ # Update doc status
645
+ await self.full_docs.change_status(doc_id, DocStatus.PROCESSED)
646
+ except Exception as e:
647
+ # Mark as failed if any step fails
648
+ await self.full_docs.change_status(doc_id, DocStatus.FAILED)
649
+ raise e
650
+ except Exception as e:
651
+ import traceback
652
+
653
+ error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
654
+ logger.error(error_msg)
655
+ continue
656
+ logger.info(f"Stored {chunk_cnt} chunks from {len(new_docs)} documents")
657
+
658
+ async def apipeline_process_extract_graph(self):
659
+ """Get pendding or failed chunks, extract entities and relationships from each chunk"""
660
+ # 1. get all pending and failed chunks
661
+ _todo_chunk_keys = []
662
+ _failed_chunks = await self.text_chunks.get_by_status_and_ids(
663
+ status=DocStatus.FAILED, ids=None
664
+ )
665
+ _pendding_chunks = await self.text_chunks.get_by_status_and_ids(
666
+ status=DocStatus.PENDING, ids=None
667
+ )
668
+ if _failed_chunks:
669
+ _todo_chunk_keys.extend([doc["id"] for doc in _failed_chunks])
670
+ if _pendding_chunks:
671
+ _todo_chunk_keys.extend([doc["id"] for doc in _pendding_chunks])
672
+ if not _todo_chunk_keys:
673
+ logger.info("All chunks have been processed or are duplicates")
674
+ return None
675
+
676
+ # Process documents in batches
677
+ batch_size = self.addon_params.get("insert_batch_size", 10)
678
+
679
+ semaphore = asyncio.Semaphore(
680
+ batch_size
681
+ ) # Control the number of tasks that are processed simultaneously
682
+
683
+ async def process_chunk(chunk_id):
684
+ async with semaphore:
685
+ chunks = {
686
+ i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])
687
+ }
688
+ # Extract and store entities and relationships
689
+ try:
690
+ maybe_new_kg = await extract_entities(
691
+ chunks,
692
+ knowledge_graph_inst=self.chunk_entity_relation_graph,
693
+ entity_vdb=self.entities_vdb,
694
+ relationships_vdb=self.relationships_vdb,
695
+ llm_response_cache=self.llm_response_cache,
696
+ global_config=asdict(self),
697
+ )
698
+ if maybe_new_kg is None:
699
+ logger.info("No entities or relationships extracted!")
700
+ # Update status to processed
701
+ await self.text_chunks.change_status(chunk_id, DocStatus.PROCESSED)
702
+ except Exception as e:
703
+ logger.error("Failed to extract entities and relationships")
704
+ # Mark as failed if any step fails
705
+ await self.text_chunks.change_status(chunk_id, DocStatus.FAILED)
706
+ raise e
707
+
708
+ with tqdm_async(
709
+ total=len(_todo_chunk_keys),
710
+ desc="\nLevel 1 - Processing chunks",
711
+ unit="chunk",
712
+ position=0,
713
+ ) as progress:
714
+ tasks = []
715
+ for chunk_id in _todo_chunk_keys:
716
+ task = asyncio.create_task(process_chunk(chunk_id))
717
+ tasks.append(task)
718
+
719
+ for future in asyncio.as_completed(tasks):
720
+ await future
721
+ progress.update(1)
722
+ progress.set_postfix(
723
+ {
724
+ "LLM call": statistic_data["llm_call"],
725
+ "LLM cache": statistic_data["llm_cache"],
726
+ }
727
+ )
728
+
729
+ # Ensure all indexes are updated after each document
730
+ await self._insert_done()
731
+
732
  async def _insert_done(self):
733
  tasks = []
734
  for storage_inst in [
lightrag/operate.py CHANGED
@@ -20,6 +20,7 @@ from .utils import (
20
  handle_cache,
21
  save_to_cache,
22
  CacheData,
 
23
  )
24
  from .base import (
25
  BaseGraphStorage,
@@ -96,6 +97,10 @@ async def _handle_entity_relation_summary(
96
  description: str,
97
  global_config: dict,
98
  ) -> str:
99
  use_llm_func: callable = global_config["llm_model_func"]
100
  llm_max_tokens = global_config["llm_model_max_token_size"]
101
  tiktoken_model_name = global_config["tiktoken_model_name"]
@@ -176,6 +181,7 @@ async def _merge_nodes_then_upsert(
176
  knowledge_graph_inst: BaseGraphStorage,
177
  global_config: dict,
178
  ):
 
179
  already_entity_types = []
180
  already_source_ids = []
181
  already_description = []
@@ -356,7 +362,7 @@ async def extract_entities(
356
  llm_response_cache.global_config = new_config
357
  need_to_restore = True
358
  if history_messages:
359
- history = json.dumps(history_messages)
360
  _prompt = history + "\n" + input_text
361
  else:
362
  _prompt = input_text
@@ -368,8 +374,10 @@ async def extract_entities(
368
  if need_to_restore:
369
  llm_response_cache.global_config = global_config
370
  if cached_return:
 
 
371
  return cached_return
372
-
373
  if history_messages:
374
  res: str = await use_llm_func(
375
  input_text, history_messages=history_messages
@@ -388,6 +396,11 @@ async def extract_entities(
388
  return await use_llm_func(input_text)
389
 
390
  async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
391
  nonlocal already_processed, already_entities, already_relations
392
  chunk_key = chunk_key_dp[0]
393
  chunk_dp = chunk_key_dp[1]
@@ -451,10 +464,8 @@ async def extract_entities(
451
  now_ticks = PROMPTS["process_tickers"][
452
  already_processed % len(PROMPTS["process_tickers"])
453
  ]
454
- print(
455
  f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
456
- end="",
457
- flush=True,
458
  )
459
  return dict(maybe_nodes), dict(maybe_edges)
460
 
@@ -462,8 +473,10 @@ async def extract_entities(
462
  for result in tqdm_async(
463
  asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]),
464
  total=len(ordered_chunks),
465
- desc="Extracting entities from chunks",
466
  unit="chunk",
 
 
467
  ):
468
  results.append(await result)
469
 
@@ -474,7 +487,7 @@ async def extract_entities(
474
  maybe_nodes[k].extend(v)
475
  for k, v in m_edges.items():
476
  maybe_edges[tuple(sorted(k))].extend(v)
477
- logger.info("Inserting entities into storage...")
478
  all_entities_data = []
479
  for result in tqdm_async(
480
  asyncio.as_completed(
@@ -484,12 +497,14 @@ async def extract_entities(
484
  ]
485
  ),
486
  total=len(maybe_nodes),
487
- desc="Inserting entities",
488
  unit="entity",
 
 
489
  ):
490
  all_entities_data.append(await result)
491
 
492
- logger.info("Inserting relationships into storage...")
493
  all_relationships_data = []
494
  for result in tqdm_async(
495
  asyncio.as_completed(
@@ -501,8 +516,10 @@ async def extract_entities(
501
  ]
502
  ),
503
  total=len(maybe_edges),
504
- desc="Inserting relationships",
505
  unit="relationship",
 
 
506
  ):
507
  all_relationships_data.append(await result)
508
 
 
20
  handle_cache,
21
  save_to_cache,
22
  CacheData,
23
+ statistic_data,
24
  )
25
  from .base import (
26
  BaseGraphStorage,
 
97
  description: str,
98
  global_config: dict,
99
  ) -> str:
100
+ """Handle entity relation summary
101
+ For each entity or relation, the input is the existing description combined with the new description.
102
+ If too long, use LLM to summarize.
103
+ """
104
  use_llm_func: callable = global_config["llm_model_func"]
105
  llm_max_tokens = global_config["llm_model_max_token_size"]
106
  tiktoken_model_name = global_config["tiktoken_model_name"]
 
181
  knowledge_graph_inst: BaseGraphStorage,
182
  global_config: dict,
183
  ):
184
+ """Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert."""
185
  already_entity_types = []
186
  already_source_ids = []
187
  already_description = []
 
362
  llm_response_cache.global_config = new_config
363
  need_to_restore = True
364
  if history_messages:
365
+ history = json.dumps(history_messages, ensure_ascii=False)
366
  _prompt = history + "\n" + input_text
367
  else:
368
  _prompt = input_text
 
374
  if need_to_restore:
375
  llm_response_cache.global_config = global_config
376
  if cached_return:
377
+ logger.debug(f"Found cache for {arg_hash}")
378
+ statistic_data["llm_cache"] += 1
379
  return cached_return
380
+ statistic_data["llm_call"] += 1
381
  if history_messages:
382
  res: str = await use_llm_func(
383
  input_text, history_messages=history_messages
 
396
  return await use_llm_func(input_text)
397
 
398
  async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
399
+ """ "Prpocess a single chunk
400
+ Args:
401
+ chunk_key_dp (tuple[str, TextChunkSchema]):
402
+ ("chunck-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
403
+ """
404
  nonlocal already_processed, already_entities, already_relations
405
  chunk_key = chunk_key_dp[0]
406
  chunk_dp = chunk_key_dp[1]
 
464
  now_ticks = PROMPTS["process_tickers"][
465
  already_processed % len(PROMPTS["process_tickers"])
466
  ]
467
+ logger.debug(
468
  f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
 
 
469
  )
470
  return dict(maybe_nodes), dict(maybe_edges)
471
 
 
473
  for result in tqdm_async(
474
  asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]),
475
  total=len(ordered_chunks),
476
+ desc="Level 2 - Extracting entities and relationships",
477
  unit="chunk",
478
+ position=1,
479
+ leave=False,
480
  ):
481
  results.append(await result)
482
 
 
487
  maybe_nodes[k].extend(v)
488
  for k, v in m_edges.items():
489
  maybe_edges[tuple(sorted(k))].extend(v)
490
+ logger.debug("Inserting entities into storage...")
491
  all_entities_data = []
492
  for result in tqdm_async(
493
  asyncio.as_completed(
 
497
  ]
498
  ),
499
  total=len(maybe_nodes),
500
+ desc="Level 3 - Inserting entities",
501
  unit="entity",
502
+ position=2,
503
+ leave=False,
504
  ):
505
  all_entities_data.append(await result)
506
 
507
+ logger.debug("Inserting relationships into storage...")
508
  all_relationships_data = []
509
  for result in tqdm_async(
510
  asyncio.as_completed(
 
516
  ]
517
  ),
518
  total=len(maybe_edges),
519
+ desc="Level 3 - Inserting relationships",
520
  unit="relationship",
521
+ position=3,
522
+ leave=False,
523
  ):
524
  all_relationships_data.append(await result)
525
 
lightrag/utils.py CHANGED
@@ -30,13 +30,18 @@ class UnlimitedSemaphore:
30
 
31
  ENCODER = None
32
 
 
 
33
  logger = logging.getLogger("lightrag")
34
 
 
 
 
35
 
36
  def set_logger(log_file: str):
37
  logger.setLevel(logging.DEBUG)
38
 
39
- file_handler = logging.FileHandler(log_file)
40
  file_handler.setLevel(logging.DEBUG)
41
 
42
  formatter = logging.Formatter(
@@ -453,7 +458,8 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
453
  return None, None, None, None
454
 
455
  # For naive mode, only use simple cache matching
456
- if mode == "naive":
 
457
  if exists_func(hashing_kv, "get_by_mode_and_id"):
458
  mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
459
  else:
@@ -473,7 +479,9 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
473
  quantized = min_val = max_val = None
474
  if is_embedding_cache_enabled:
475
  # Use embedding cache
476
- embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
 
 
477
  llm_model_func = hashing_kv.global_config.get("llm_model_func")
478
 
479
  current_embedding = await embedding_model_func([prompt])
 
30
 
31
  ENCODER = None
32
 
33
+ statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0}
34
+
35
  logger = logging.getLogger("lightrag")
36
 
37
+ # Set httpx logging level to WARNING
38
+ logging.getLogger("httpx").setLevel(logging.WARNING)
39
+
40
 
41
  def set_logger(log_file: str):
42
  logger.setLevel(logging.DEBUG)
43
 
44
+ file_handler = logging.FileHandler(log_file, encoding="utf-8")
45
  file_handler.setLevel(logging.DEBUG)
46
 
47
  formatter = logging.Formatter(
 
458
  return None, None, None, None
459
 
460
  # For naive mode, only use simple cache matching
461
+ # if mode == "naive":
462
+ if mode == "default":
463
  if exists_func(hashing_kv, "get_by_mode_and_id"):
464
  mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
465
  else:
 
479
  quantized = min_val = max_val = None
480
  if is_embedding_cache_enabled:
481
  # Use embedding cache
482
+ embedding_model_func = hashing_kv.global_config[
483
+ "embedding_func"
484
+ ].func # ["func"]
485
  llm_model_func = hashing_kv.global_config.get("llm_model_func")
486
 
487
  current_embedding = await embedding_model_func([prompt])