yangdx committed
Commit 3bf5811 · 2 Parent(s): 4d227cd 41b651b

Merge branch 'main' into yangdx

.gitignore CHANGED
@@ -21,4 +21,4 @@ rag_storage
 venv/
 examples/input/
 examples/output/
-test_results.json
+.DS_Store
README.md CHANGED
@@ -330,6 +330,26 @@ rag = LightRAG(
 with open("./newText.txt") as f:
     rag.insert(f.read())
 ```
+### Separate Keyword Extraction
+We've introduced a new function `query_with_separate_keyword_extraction` to enhance the keyword extraction capabilities. This function separates the keyword extraction process from the user's prompt, focusing solely on the query to improve the relevance of extracted keywords.
+
+##### How It Works
+The function operates by dividing the input into two parts:
+- `User Query`
+- `Prompt`
+
+It then performs keyword extraction exclusively on the `user query`. This separation ensures that the extraction process is focused and relevant, unaffected by any additional language in the `prompt`. It also allows the `prompt` to serve purely for response formatting, maintaining the intent and clarity of the user's original question.
+
+##### Usage Example
+This example shows how to tailor the function for educational content, focusing on detailed explanations for older students.
+
+```python
+rag.query_with_separate_keyword_extraction(
+    query="Explain the law of gravity",
+    prompt="Provide a detailed explanation suitable for high school students studying physics.",
+    param=QueryParam(mode="hybrid")
+)
+```
 
 ### Using Neo4J for Storage
 
@@ -361,6 +381,7 @@ see test_neo4j.py for a working example.
 ### Using PostgreSQL for Storage
 For production-level scenarios you will most likely want to leverage an enterprise solution. PostgreSQL can provide a one-stop solution for you as KV store, VectorDB (pgvector) and GraphDB (Apache AGE).
 * PostgreSQL is lightweight; the whole binary distribution including all necessary plugins can be zipped to 40MB: see the [Windows Release](https://github.com/ShanGor/apache-age-windows/releases/tag/PG17%2Fv1.5.0-rc0), and it is easy to install for Linux/Mac.
+* If you prefer Docker, please start with this image if you are a beginner to avoid hiccups (DO read the overview): https://hub.docker.com/r/shangor/postgres-for-rag
 * How to start? Refer to: [examples/lightrag_zhipu_postgres_demo.py](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_zhipu_postgres_demo.py)
 * Create index for AGE example: (change `dickens` below to your graph name if necessary)
 ```
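The `lightrag/lightrag.py` diff further down also adds an awaitable counterpart, `aquery_with_separate_keyword_extraction`. A minimal sketch of calling it from async code — the `main` wrapper here is illustrative, and `rag` is assumed to be a LightRAG instance configured as in the setup examples above:

```python
import asyncio

from lightrag import QueryParam


async def main(rag):
    # `rag` is a configured LightRAG instance (LLM + embedding functions set up)
    response = await rag.aquery_with_separate_keyword_extraction(
        query="Explain the law of gravity",
        prompt="Provide a detailed explanation suitable for high school students.",
        param=QueryParam(mode="hybrid"),
    )
    print(response)


# asyncio.run(main(rag))
```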
examples/copy_llm_cache_to_another_storage.py ADDED
@@ -0,0 +1,97 @@
+"""
+Sometimes you need to switch storage solutions, but you want to save LLM tokens and time.
+This handy script helps you copy the LLM caches from one storage solution to another.
+(Not all storage implementations are supported.)
+"""
+
+import asyncio
+import logging
+import os
+from dotenv import load_dotenv
+
+from lightrag.kg.postgres_impl import PostgreSQLDB, PGKVStorage
+from lightrag.storage import JsonKVStorage
+
+load_dotenv()
+ROOT_DIR = os.environ.get("ROOT_DIR")
+WORKING_DIR = f"{ROOT_DIR}/dickens"
+
+logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
+
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+# AGE
+os.environ["AGE_GRAPH_NAME"] = "chinese"
+
+postgres_db = PostgreSQLDB(
+    config={
+        "host": "localhost",
+        "port": 15432,
+        "user": "rag",
+        "password": "rag",
+        "database": "r2",
+    }
+)
+
+
+async def copy_from_postgres_to_json():
+    await postgres_db.initdb()
+
+    from_llm_response_cache = PGKVStorage(
+        namespace="llm_response_cache",
+        global_config={"embedding_batch_num": 6},
+        embedding_func=None,
+        db=postgres_db,
+    )
+
+    to_llm_response_cache = JsonKVStorage(
+        namespace="llm_response_cache",
+        global_config={"working_dir": WORKING_DIR},
+        embedding_func=None,
+    )
+
+    kv = {}
+    for c_id in await from_llm_response_cache.all_keys():
+        print(f"Copying {c_id}")
+        workspace = c_id["workspace"]
+        mode = c_id["mode"]
+        _id = c_id["id"]
+        postgres_db.workspace = workspace
+        obj = await from_llm_response_cache.get_by_mode_and_id(mode, _id)
+        if mode not in kv:
+            kv[mode] = {}
+        kv[mode][_id] = obj[_id]
+        print(f"Object {obj}")
+    await to_llm_response_cache.upsert(kv)
+    await to_llm_response_cache.index_done_callback()
+    print("Mission accomplished!")
+
+
+async def copy_from_json_to_postgres():
+    await postgres_db.initdb()
+
+    from_llm_response_cache = JsonKVStorage(
+        namespace="llm_response_cache",
+        global_config={"working_dir": WORKING_DIR},
+        embedding_func=None,
+    )
+
+    to_llm_response_cache = PGKVStorage(
+        namespace="llm_response_cache",
+        global_config={"embedding_batch_num": 6},
+        embedding_func=None,
+        db=postgres_db,
+    )
+
+    for mode in await from_llm_response_cache.all_keys():
+        print(f"Copying {mode}")
+        caches = await from_llm_response_cache.get_by_id(mode)
+        for k, v in caches.items():
+            item = {mode: {k: v}}
+            print(f"\tCopying {item}")
+            await to_llm_response_cache.upsert(item)
+
+
+if __name__ == "__main__":
+    asyncio.run(copy_from_json_to_postgres())
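As committed, the script's entry point runs only the JSON → PostgreSQL direction. A sketch of running the opposite direction instead, assuming the same `.env` and connection settings as above:

```python
if __name__ == "__main__":
    # Copy PostgreSQL -> JSON instead of the default JSON -> PostgreSQL
    asyncio.run(copy_from_postgres_to_json())
```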
get_all_edges_nx.py → examples/get_all_edges_nx.py RENAMED
File without changes
examples/lightrag_oracle_demo.py CHANGED
@@ -20,7 +20,8 @@ BASE_URL = "http://xxx.xxx.xxx.xxx:8088/v1/"
 APIKEY = "ocigenerativeai"
 CHATMODEL = "cohere.command-r-plus"
 EMBEDMODEL = "cohere.embed-multilingual-v3.0"
-
+CHUNK_TOKEN_SIZE = 1024
+MAX_TOKENS = 4000
 
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
@@ -86,30 +87,46 @@ async def main():
     # We use Oracle DB as the KV/vector/graph storage
     # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
     rag = LightRAG(
-        enable_llm_cache=False,
+        # log_level="DEBUG",
         working_dir=WORKING_DIR,
-        chunk_token_size=512,
+        entity_extract_max_gleaning=1,
+        enable_llm_cache=True,
+        enable_llm_cache_for_entity_extract=True,
+        embedding_cache_config=None,  # {"enabled": True, "similarity_threshold": 0.90},
+        chunk_token_size=CHUNK_TOKEN_SIZE,
+        llm_model_max_token_size=MAX_TOKENS,
         llm_model_func=llm_model_func,
         embedding_func=EmbeddingFunc(
             embedding_dim=embedding_dimension,
-            max_token_size=512,
+            max_token_size=500,
            func=embedding_func,
         ),
         graph_storage="OracleGraphStorage",
         kv_storage="OracleKVStorage",
         vector_storage="OracleVectorDBStorage",
+        addon_params={
+            "example_number": 1,
+            "language": "Simplfied Chinese",
+            "entity_types": ["organization", "person", "geo", "event"],
+            "insert_batch_size": 2,
+        },
     )
 
     # Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
-    rag.graph_storage_cls.db = oracle_db
-    rag.key_string_value_json_storage_cls.db = oracle_db
-    rag.vector_db_storage_cls.db = oracle_db
-    # add embedding_func for graph database, it's deleted in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c
-    rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func
+    rag.set_storage_client(db_client=oracle_db)
 
     # Extract and Insert into LightRAG storage
-    with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
-        await rag.ainsert(f.read())
+    with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
+        all_text = f.read()
+        texts = [x for x in all_text.split("\n") if x]
+
+    # New mode: use the pipeline
+    await rag.apipeline_process_documents(texts)
+    await rag.apipeline_process_chunks()
+    await rag.apipeline_process_extract_graph()
+
+    # Old method: use ainsert
+    # await rag.ainsert(texts)
 
     # Perform search in different modes
     modes = ["naive", "local", "global", "hybrid"]
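Because each pipeline stage records a per-document status (PENDING, PROCESSED, FAILED — see the `DocStatus` handling in the `lightrag/lightrag.py` diff below), an interrupted run can simply be re-executed and will pick up unfinished work. A minimal sketch, assuming `rag` and `texts` are set up as in the demo:

```python
async def resume_pipeline(rag, texts):
    # Stage 1: de-duplicate and register new documents (already-stored docs are skipped)
    await rag.apipeline_process_documents(texts)
    # Stage 2: chunk every document still in PENDING or FAILED status
    await rag.apipeline_process_chunks()
    # Stage 3: extract entities/relationships for PENDING or FAILED chunks
    await rag.apipeline_process_extract_graph()
```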
examples/query_keyword_separation_example.py ADDED
@@ -0,0 +1,116 @@
+import os
+import asyncio
+from lightrag import LightRAG, QueryParam
+from lightrag.utils import EmbeddingFunc
+import numpy as np
+from dotenv import load_dotenv
+import logging
+from openai import AzureOpenAI
+
+logging.basicConfig(level=logging.INFO)
+
+load_dotenv()
+
+AZURE_OPENAI_API_VERSION = os.getenv("AZURE_OPENAI_API_VERSION")
+AZURE_OPENAI_DEPLOYMENT = os.getenv("AZURE_OPENAI_DEPLOYMENT")
+AZURE_OPENAI_API_KEY = os.getenv("AZURE_OPENAI_API_KEY")
+AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
+
+AZURE_EMBEDDING_DEPLOYMENT = os.getenv("AZURE_EMBEDDING_DEPLOYMENT")
+AZURE_EMBEDDING_API_VERSION = os.getenv("AZURE_EMBEDDING_API_VERSION")
+
+WORKING_DIR = "./dickens"
+
+if os.path.exists(WORKING_DIR):
+    import shutil
+
+    shutil.rmtree(WORKING_DIR)
+
+os.mkdir(WORKING_DIR)
+
+
+async def llm_model_func(
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> str:
+    client = AzureOpenAI(
+        api_key=AZURE_OPENAI_API_KEY,
+        api_version=AZURE_OPENAI_API_VERSION,
+        azure_endpoint=AZURE_OPENAI_ENDPOINT,
+    )
+
+    messages = []
+    if system_prompt:
+        messages.append({"role": "system", "content": system_prompt})
+    if history_messages:
+        messages.extend(history_messages)
+    messages.append({"role": "user", "content": prompt})
+
+    chat_completion = client.chat.completions.create(
+        model=AZURE_OPENAI_DEPLOYMENT,  # model = "deployment_name"
+        messages=messages,
+        temperature=kwargs.get("temperature", 0),
+        top_p=kwargs.get("top_p", 1),
+        n=kwargs.get("n", 1),
+    )
+    return chat_completion.choices[0].message.content
+
+
+async def embedding_func(texts: list[str]) -> np.ndarray:
+    client = AzureOpenAI(
+        api_key=AZURE_OPENAI_API_KEY,
+        api_version=AZURE_EMBEDDING_API_VERSION,
+        azure_endpoint=AZURE_OPENAI_ENDPOINT,
+    )
+    embedding = client.embeddings.create(model=AZURE_EMBEDDING_DEPLOYMENT, input=texts)
+
+    embeddings = [item.embedding for item in embedding.data]
+    return np.array(embeddings)
+
+
+async def test_funcs():
+    result = await llm_model_func("How are you?")
+    print("llm_model_func response: ", result)
+
+    result = await embedding_func(["How are you?"])
+    print("embedding_func result: ", result.shape)
+    print("Embedding dimension: ", result.shape[1])
+
+
+asyncio.run(test_funcs())
+
+embedding_dimension = 3072
+
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=llm_model_func,
+    embedding_func=EmbeddingFunc(
+        embedding_dim=embedding_dimension,
+        max_token_size=8192,
+        func=embedding_func,
+    ),
+)
+
+book1 = open("./book_1.txt", encoding="utf-8")
+book2 = open("./book_2.txt", encoding="utf-8")
+
+rag.insert([book1.read(), book2.read()])
+
+
+# Example function demonstrating the new query_with_separate_keyword_extraction usage
+async def run_example():
+    query = "What are the top themes in this story?"
+    prompt = "Please simplify the response for a young audience."
+
+    # Using the new method to ensure the keyword extraction is only applied to the query
+    response = rag.query_with_separate_keyword_extraction(
+        query=query,
+        prompt=prompt,
+        param=QueryParam(mode="hybrid"),  # Adjust QueryParam mode as necessary
+    )
+
+    print("Extracted Response:", response)
+
+
+# Run the example asynchronously
+if __name__ == "__main__":
+    asyncio.run(run_example())
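One caveat worth noting about the example above: `run_example` is driven by `asyncio.run`, yet it calls the synchronous `query_with_separate_keyword_extraction`, whose internal `run_until_complete` can collide with the already-running event loop; the two `open(...)` handles are also never closed. A safer sketch that uses context managers and awaits the async variant directly — the function name here is illustrative:

```python
# Close the book files deterministically with context managers
with open("./book_1.txt", encoding="utf-8") as book1, open(
    "./book_2.txt", encoding="utf-8"
) as book2:
    rag.insert([book1.read(), book2.read()])


async def run_example_async():
    query = "What are the top themes in this story?"
    prompt = "Please simplify the response for a young audience."

    # Await the async variant inside the running event loop
    response = await rag.aquery_with_separate_keyword_extraction(
        query=query,
        prompt=prompt,
        param=QueryParam(mode="hybrid"),
    )
    print("Extracted Response:", response)


if __name__ == "__main__":
    asyncio.run(run_example_async())
```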
test.py → examples/test.py RENAMED
File without changes
test_chromadb.py → examples/test_chromadb.py RENAMED
File without changes
test_neo4j.py → examples/test_neo4j.py RENAMED
File without changes
lightrag/__init__.py CHANGED
@@ -1,5 +1,5 @@
 from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
 
-__version__ = "1.1.1"
+__version__ = "1.1.2"
 __author__ = "Zirui Guo"
 __url__ = "https://github.com/HKUDS/LightRAG"
lightrag/base.py CHANGED
@@ -31,6 +31,8 @@ class QueryParam:
     max_token_for_global_context: int = 4000
     # Number of tokens for the entity descriptions
     max_token_for_local_context: int = 4000
+    hl_keywords: list[str] = field(default_factory=list)
+    ll_keywords: list[str] = field(default_factory=list)
 
 
 @dataclass
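The two new fields let callers supply pre-extracted high-level and low-level keywords instead of relying on in-query extraction; this is what `aquery_with_separate_keyword_extraction` populates. A minimal usage sketch — the keyword values are illustrative, not from the source:

```python
from lightrag import QueryParam

param = QueryParam(
    mode="hybrid",
    hl_keywords=["gravitation", "classical mechanics"],  # high-level themes (illustrative)
    ll_keywords=["Newton", "inverse-square law"],  # low-level entities (illustrative)
)
# response = rag.query("Explain the law of gravity", param=param)
```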
lightrag/kg/oracle_impl.py CHANGED
@@ -153,8 +153,6 @@ class OracleDB:
                     if data is None:
                         await cursor.execute(sql)
                     else:
-                        # print(data)
-                        # print(sql)
                         await cursor.execute(sql, data)
                 await connection.commit()
         except Exception as e:
@@ -167,35 +165,64 @@
 @dataclass
 class OracleKVStorage(BaseKVStorage):
     # should pass db object to self.db
+    db: OracleDB = None
+    meta_fields = None
+
     def __post_init__(self):
         self._data = {}
-        self._max_batch_size = self.global_config["embedding_batch_num"]
+        self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
 
     ################ QUERY METHODS ################
 
     async def get_by_id(self, id: str) -> Union[dict, None]:
-        """根据 id 获取 doc_full 数据."""
+        """get doc_full data based on id."""
         SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"workspace": self.db.workspace, "id": id}
         # print("get_by_id:"+SQL)
-        res = await self.db.query(SQL, params)
+        if "llm_response_cache" == self.namespace:
+            array_res = await self.db.query(SQL, params, multirows=True)
+            res = {}
+            for row in array_res:
+                res[row["id"]] = row
+        else:
+            res = await self.db.query(SQL, params)
         if res:
-            data = res  # {"data":res}
-            # print (data)
-            return data
+            return res
+        else:
+            return None
+
+    async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
+        """Specifically for llm_response_cache."""
+        SQL = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
+        params = {"workspace": self.db.workspace, "cache_mode": mode, "id": id}
+        if "llm_response_cache" == self.namespace:
+            array_res = await self.db.query(SQL, params, multirows=True)
+            res = {}
+            for row in array_res:
+                res[row["id"]] = row
+            return res
         else:
             return None
 
-    # Query by id
     async def get_by_ids(self, ids: list[str], fields=None) -> Union[list[dict], None]:
-        """根据 id 获取 doc_chunks 数据"""
+        """get doc_chunks data based on id"""
         SQL = SQL_TEMPLATES["get_by_ids_" + self.namespace].format(
             ids=",".join([f"'{id}'" for id in ids])
         )
         params = {"workspace": self.db.workspace}
         # print("get_by_ids:"+SQL)
-        # print(params)
         res = await self.db.query(SQL, params, multirows=True)
+        if "llm_response_cache" == self.namespace:
+            modes = set()
+            dict_res: dict[str, dict] = {}
+            for row in res:
+                modes.add(row["mode"])
+            for mode in modes:
+                if mode not in dict_res:
+                    dict_res[mode] = {}
+            for row in res:
+                dict_res[row["mode"]][row["id"]] = row
+            res = [{k: v} for k, v in dict_res.items()]
         if res:
             data = res  # [{"data":i} for i in res]
             # print(data)
@@ -203,38 +230,43 @@ class OracleKVStorage(BaseKVStorage):
         else:
             return None
 
+    async def get_by_status_and_ids(
+        self, status: str, ids: list[str]
+    ) -> Union[list[dict], None]:
+        """Specifically for llm_response_cache."""
+        if ids is not None:
+            SQL = SQL_TEMPLATES["get_by_status_ids_" + self.namespace].format(
+                ids=",".join([f"'{id}'" for id in ids])
+            )
+        else:
+            SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
+        params = {"workspace": self.db.workspace, "status": status}
+        res = await self.db.query(SQL, params, multirows=True)
+        if res:
+            return res
+        else:
+            return None
+
     async def filter_keys(self, keys: list[str]) -> set[str]:
-        """过滤掉重复内容"""
+        """Return keys that don't exist in storage"""
         SQL = SQL_TEMPLATES["filter_keys"].format(
             table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys])
         )
         params = {"workspace": self.db.workspace}
-        try:
-            await self.db.query(SQL, params)
-        except Exception as e:
-            logger.error(f"Oracle database error: {e}")
-            print(SQL)
-            print(params)
         res = await self.db.query(SQL, params, multirows=True)
-        data = None
         if res:
             exist_keys = [key["id"] for key in res]
             data = set([s for s in keys if s not in exist_keys])
+            return data
         else:
-            exist_keys = []
-            data = set([s for s in keys if s not in exist_keys])
-        return data
+            return set(keys)
 
     ################ INSERT METHODS ################
     async def upsert(self, data: dict[str, dict]):
-        left_data = {k: v for k, v in data.items() if k not in self._data}
-        self._data.update(left_data)
-        # print(self._data)
-        # values = []
         if self.namespace == "text_chunks":
             list_data = [
                 {
-                    "__id__": k,
+                    "id": k,
                     **{k1: v1 for k1, v1 in v.items()},
                 }
                 for k, v in data.items()
@@ -250,35 +282,50 @@ class OracleKVStorage(BaseKVStorage):
             embeddings = np.concatenate(embeddings_list)
             for i, d in enumerate(list_data):
                 d["__vector__"] = embeddings[i]
-            # print(list_data)
+
+            merge_sql = SQL_TEMPLATES["merge_chunk"]
             for item in list_data:
-                merge_sql = SQL_TEMPLATES["merge_chunk"]
-                data = {
-                    "check_id": item["__id__"],
-                    "id": item["__id__"],
+                _data = {
+                    "id": item["id"],
                     "content": item["content"],
                     "workspace": self.db.workspace,
                     "tokens": item["tokens"],
                     "chunk_order_index": item["chunk_order_index"],
                     "full_doc_id": item["full_doc_id"],
                     "content_vector": item["__vector__"],
+                    "status": item["status"],
                 }
-                # print(merge_sql)
-                await self.db.execute(merge_sql, data)
-
+                await self.db.execute(merge_sql, _data)
         if self.namespace == "full_docs":
-            for k, v in self._data.items():
+            for k, v in data.items():
                 # values.clear()
                 merge_sql = SQL_TEMPLATES["merge_doc_full"]
-                data = {
-                    "check_id": k,
+                _data = {
                     "id": k,
                     "content": v["content"],
                     "workspace": self.db.workspace,
                 }
-                # print(merge_sql)
-                await self.db.execute(merge_sql, data)
-        return left_data
+                await self.db.execute(merge_sql, _data)
+
+        if self.namespace == "llm_response_cache":
+            for mode, items in data.items():
+                for k, v in items.items():
+                    upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
+                    _data = {
+                        "workspace": self.db.workspace,
+                        "id": k,
+                        "original_prompt": v["original_prompt"],
+                        "return_value": v["return"],
+                        "cache_mode": mode,
+                    }
+
+                    await self.db.execute(upsert_sql, _data)
+        return None
+
+    async def change_status(self, id: str, status: str):
+        SQL = SQL_TEMPLATES["change_status"].format(table_name=N_T[self.namespace])
+        params = {"workspace": self.db.workspace, "id": id, "status": status}
+        await self.db.execute(SQL, params)
 
     async def index_done_callback(self):
         if self.namespace in ["full_docs", "text_chunks"]:
@@ -287,6 +334,8 @@ class OracleKVStorage(BaseKVStorage):
 
 @dataclass
 class OracleVectorDBStorage(BaseVectorStorage):
+    # should pass db object to self.db
+    db: OracleDB = None
     cosine_better_than_threshold: float = 0.2
 
     def __post_init__(self):
@@ -328,7 +377,7 @@ class OracleGraphStorage(BaseGraphStorage):
 
     def __post_init__(self):
         """从graphml文件加载图"""
-        self._max_batch_size = self.global_config["embedding_batch_num"]
+        self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
 
     #################### insert method ################
 
@@ -362,7 +411,6 @@ class OracleGraphStorage(BaseGraphStorage):
                 "content": content,
                 "content_vector": content_vector,
             }
-            # print(merge_sql)
            await self.db.execute(merge_sql, data)
            # self._graph.add_node(node_id, **node_data)
 
@@ -564,20 +612,26 @@ N_T = {
 TABLES = {
     "LIGHTRAG_DOC_FULL": {
         "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (
-                    id varchar(256) PRIMARY KEY,
+                    id varchar(256),
                     workspace varchar(1024),
                     doc_name varchar(1024),
                     content CLOB,
                     meta JSON,
+                    content_summary varchar(1024),
+                    content_length NUMBER,
+                    status varchar(256),
+                    chunks_count NUMBER,
                     createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-                    updatetime TIMESTAMP DEFAULT NULL
+                    updatetime TIMESTAMP DEFAULT NULL,
+                    error varchar(4096)
                     )"""
     },
     "LIGHTRAG_DOC_CHUNKS": {
         "ddl": """CREATE TABLE LIGHTRAG_DOC_CHUNKS (
-                    id varchar(256) PRIMARY KEY,
+                    id varchar(256),
                     workspace varchar(1024),
                     full_doc_id varchar(256),
+                    status varchar(256),
                     chunk_order_index NUMBER,
                     tokens NUMBER,
                     content CLOB,
@@ -619,9 +673,15 @@ TABLES = {
     "LIGHTRAG_LLM_CACHE": {
         "ddl": """CREATE TABLE LIGHTRAG_LLM_CACHE (
                     id varchar(256) PRIMARY KEY,
-                    send clob,
-                    return clob,
-                    model varchar(1024),
+                    workspace varchar(1024),
+                    cache_mode varchar(256),
+                    model_name varchar(256),
+                    original_prompt clob,
+                    return_value clob,
+                    embedding CLOB,
+                    embedding_shape NUMBER,
+                    embedding_min NUMBER,
+                    embedding_max NUMBER,
                     createtime TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                     updatetime TIMESTAMP DEFAULT NULL
                     )"""
@@ -646,23 +706,44 @@ TABLES = {
 
 SQL_TEMPLATES = {
     # SQL for KVStorage
-    "get_by_id_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
-    "get_by_id_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id",
-    "get_by_ids_full_docs": "select ID,NVL(content,'') as content from LIGHTRAG_DOC_FULL where workspace=:workspace and ID in ({ids})",
-    "get_by_ids_text_chunks": "select ID,TOKENS,NVL(content,'') as content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})",
+    "get_by_id_full_docs": "select ID,content,status from LIGHTRAG_DOC_FULL where workspace=:workspace and ID=:id",
+    "get_by_id_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID=:id",
+    "get_by_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
+        FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id=:id""",
+    "get_by_mode_id_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
+        FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND cache_mode=:cache_mode AND id=:id""",
+    "get_by_ids_llm_response_cache": """SELECT id, original_prompt, NVL(return_value, '') as "return", cache_mode as "mode"
+        FROM LIGHTRAG_LLM_CACHE WHERE workspace=:workspace AND id IN ({ids})""",
+    "get_by_ids_full_docs": "select t.*,createtime as created_at from LIGHTRAG_DOC_FULL t where workspace=:workspace and ID in ({ids})",
+    "get_by_ids_text_chunks": "select ID,TOKENS,content,CHUNK_ORDER_INDEX,FULL_DOC_ID from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and ID in ({ids})",
+    "get_by_status_ids_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status and ID in ({ids})",
+    "get_by_status_ids_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status and ID in ({ids})",
+    "get_by_status_full_docs": "select id,status from LIGHTRAG_DOC_FULL t where workspace=:workspace AND status=:status",
+    "get_by_status_text_chunks": "select id,status from LIGHTRAG_DOC_CHUNKS where workspace=:workspace and status=:status",
     "filter_keys": "select id from {table_name} where workspace=:workspace and id in ({ids})",
-    "merge_doc_full": """ MERGE INTO LIGHTRAG_DOC_FULL a
-                    USING DUAL
-                    ON (a.id = :check_id)
-                    WHEN NOT MATCHED THEN
-                    INSERT(id,content,workspace) values(:id,:content,:workspace)
-                    """,
-    "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS a
-                    USING DUAL
-                    ON (a.id = :check_id)
-                    WHEN NOT MATCHED THEN
-                    INSERT(id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector)
-                    values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector) """,
+    "change_status": "update {table_name} set status=:status,updatetime=SYSDATE where workspace=:workspace and id=:id",
+    "merge_doc_full": """MERGE INTO LIGHTRAG_DOC_FULL a
+        USING DUAL
+        ON (a.id = :id and a.workspace = :workspace)
+        WHEN NOT MATCHED THEN
+        INSERT(id,content,workspace) values(:id,:content,:workspace)""",
+    "merge_chunk": """MERGE INTO LIGHTRAG_DOC_CHUNKS
+        USING DUAL
+        ON (id = :id and workspace = :workspace)
+        WHEN NOT MATCHED THEN INSERT
+        (id,content,workspace,tokens,chunk_order_index,full_doc_id,content_vector,status)
+        values (:id,:content,:workspace,:tokens,:chunk_order_index,:full_doc_id,:content_vector,:status) """,
+    "upsert_llm_response_cache": """MERGE INTO LIGHTRAG_LLM_CACHE a
+        USING DUAL
+        ON (a.id = :id)
+        WHEN NOT MATCHED THEN
+        INSERT (workspace,id,original_prompt,return_value,cache_mode)
+        VALUES (:workspace,:id,:original_prompt,:return_value,:cache_mode)
+        WHEN MATCHED THEN UPDATE
+        SET original_prompt = :original_prompt,
+        return_value = :return_value,
+        cache_mode = :cache_mode,
+        updatetime = SYSDATE""",
     # SQL for VectorStorage
     "entities": """SELECT name as entity_name FROM
         (SELECT id,name,VECTOR_DISTANCE(content_vector,vector(:embedding_string,{dimension},{dtype}),COSINE) as distance
@@ -714,16 +795,22 @@ SQL_TEMPLATES = {
                     COLUMNS (a.name as source_name,b.name as target_name))""",
     "merge_node": """MERGE INTO LIGHTRAG_GRAPH_NODES a
                     USING DUAL
-                    ON (a.workspace = :workspace and a.name=:name and a.source_chunk_id=:source_chunk_id)
+                    ON (a.workspace=:workspace and a.name=:name)
                     WHEN NOT MATCHED THEN
                     INSERT(workspace,name,entity_type,description,source_chunk_id,content,content_vector)
-                    values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector) """,
+                    values (:workspace,:name,:entity_type,:description,:source_chunk_id,:content,:content_vector)
+                    WHEN MATCHED THEN
+                    UPDATE SET
+                    entity_type=:entity_type,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""",
     "merge_edge": """MERGE INTO LIGHTRAG_GRAPH_EDGES a
                     USING DUAL
-                    ON (a.workspace = :workspace and a.source_name=:source_name and a.target_name=:target_name and a.source_chunk_id=:source_chunk_id)
+                    ON (a.workspace=:workspace and a.source_name=:source_name and a.target_name=:target_name)
                     WHEN NOT MATCHED THEN
                     INSERT(workspace,source_name,target_name,weight,keywords,description,source_chunk_id,content,content_vector)
-                    values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector) """,
+                    values (:workspace,:source_name,:target_name,:weight,:keywords,:description,:source_chunk_id,:content,:content_vector)
+                    WHEN MATCHED THEN
+                    UPDATE SET
+                    weight=:weight,keywords=:keywords,description=:description,source_chunk_id=:source_chunk_id,content=:content,content_vector=:content_vector,updatetime=SYSDATE""",
    "get_all_nodes": """WITH t0 AS (
                     SELECT name AS id, entity_type AS label, entity_type, description,
                     '["' || replace(source_chunk_id, '<SEP>', '","') || '"]' source_chunk_ids
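The new `change_status` method and the status-aware SQL templates are what the pipeline methods added in `lightrag/lightrag.py` (below) rely on. A minimal sketch of flipping a chunk's status directly — `text_chunks` and the chunk id are illustrative stand-ins for an `OracleKVStorage` wired to an `OracleDB`:

```python
from lightrag.base import DocStatus


async def mark_processed(text_chunks, chunk_id: str):
    # text_chunks: OracleKVStorage(namespace="text_chunks", ...) with .db set
    await text_chunks.change_status(chunk_id, DocStatus.PROCESSED)
```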
lightrag/kg/postgres_impl.py CHANGED
@@ -231,6 +231,16 @@ class PGKVStorage(BaseKVStorage):
         else:
             return None
 
+    async def all_keys(self) -> list[dict]:
+        if "llm_response_cache" == self.namespace:
+            sql = "select workspace,mode,id from lightrag_llm_cache"
+            res = await self.db.query(sql, multirows=True)
+            return res
+        else:
+            logger.error(
+                f"all_keys is only implemented for llm_response_cache, not for {self.namespace}"
+            )
+
     async def filter_keys(self, keys: List[str]) -> Set[str]:
         """Filter out duplicated content"""
         sql = SQL_TEMPLATES["filter_keys"].format(
@@ -412,7 +422,10 @@ class PGDocStatusStorage(DocStatusStorage):
 
     async def filter_keys(self, data: list[str]) -> set[str]:
         """Return keys that don't exist in storage"""
-        sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id IN ({",".join([f"'{_id}'" for _id in data])})"
+        keys = ",".join([f"'{_id}'" for _id in data])
+        sql = (
+            f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id IN ({keys})"
+        )
         result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
         # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
         if result is None:
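This `all_keys` addition is what the cache-copy script above iterates; for the `llm_response_cache` namespace each row is a dict with `workspace`, `mode`, and `id` keys. A small sketch, assuming `kv` is a `PGKVStorage` bound to an initialized `PostgreSQLDB`:

```python
async def list_cache_keys(kv):
    # kv: PGKVStorage(namespace="llm_response_cache", ..., db=postgres_db)
    for row in await kv.all_keys():
        print(row["workspace"], row["mode"], row["id"])
```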
lightrag/lightrag.py CHANGED
@@ -17,6 +17,8 @@ from .operate import (
17
  kg_query,
18
  naive_query,
19
  mix_kg_vector_query,
 
 
20
  )
21
 
22
  from .utils import (
@@ -26,6 +28,7 @@ from .utils import (
26
  convert_response_to_json,
27
  logger,
28
  set_logger,
 
29
  )
30
  from .base import (
31
  BaseGraphStorage,
@@ -36,21 +39,30 @@ from .base import (
36
  DocStatus,
37
  )
38
 
39
- from .storage import (
40
- JsonKVStorage,
41
- NanoVectorDBStorage,
42
- NetworkXStorage,
43
- JsonDocStatusStorage,
44
- )
45
-
46
  from .prompt import GRAPH_FIELD_SEP
47
 
48
-
49
- # future KG integrations
50
-
51
- # from .kg.ArangoDB_impl import (
52
- # GraphStorage as ArangoDBStorage
53
- # )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
 
56
  def lazy_external_import(module_name: str, class_name: str):
@@ -66,34 +78,13 @@ def lazy_external_import(module_name: str, class_name: str):
66
  def import_class(*args, **kwargs):
67
  import importlib
68
 
69
- # Import the module using importlib
70
  module = importlib.import_module(module_name, package=package)
71
-
72
- # Get the class from the module and instantiate it
73
  cls = getattr(module, class_name)
74
  return cls(*args, **kwargs)
75
 
76
  return import_class
77
 
78
 
79
- Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage")
80
- OracleKVStorage = lazy_external_import(".kg.oracle_impl", "OracleKVStorage")
81
- OracleGraphStorage = lazy_external_import(".kg.oracle_impl", "OracleGraphStorage")
82
- OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBStorage")
83
- MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge")
84
- MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage")
85
- ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage")
86
- TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage")
87
- TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage")
88
- TiDBGraphStorage = lazy_external_import(".kg.tidb_impl", "TiDBGraphStorage")
89
- PGKVStorage = lazy_external_import(".kg.postgres_impl", "PGKVStorage")
90
- PGVectorStorage = lazy_external_import(".kg.postgres_impl", "PGVectorStorage")
91
- AGEStorage = lazy_external_import(".kg.age_impl", "AGEStorage")
92
- PGGraphStorage = lazy_external_import(".kg.postgres_impl", "PGGraphStorage")
93
- GremlinStorage = lazy_external_import(".kg.gremlin_impl", "GremlinStorage")
94
- PGDocStatusStorage = lazy_external_import(".kg.postgres_impl", "PGDocStatusStorage")
95
-
96
-
97
  def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
98
  """
99
  Ensure that there is always an event loop available.
@@ -197,34 +188,51 @@ class LightRAG:
197
  logger.setLevel(self.log_level)
198
 
199
  logger.info(f"Logger initialized for working directory: {self.working_dir}")
 
 
 
200
 
201
- _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
 
 
202
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
203
 
204
- # @TODO: should move all storage setup here to leverage initial start params attached to self.
 
 
 
205
 
 
206
  self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
207
- self._get_storage_class()[self.kv_storage]
208
  )
209
- self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[
210
  self.vector_storage
211
- ]
212
- self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
213
  self.graph_storage
214
- ]
215
 
216
- if not os.path.exists(self.working_dir):
217
- logger.info(f"Creating working directory {self.working_dir}")
218
- os.makedirs(self.working_dir)
219
 
220
- self.llm_response_cache = self.key_string_value_json_storage_cls(
221
- namespace="llm_response_cache",
222
- global_config=asdict(self),
 
 
 
 
 
 
 
223
  embedding_func=None,
224
  )
225
 
226
- self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
227
- self.embedding_func
 
228
  )
229
 
230
  ####
@@ -232,17 +240,14 @@ class LightRAG:
232
  ####
233
  self.full_docs = self.key_string_value_json_storage_cls(
234
  namespace="full_docs",
235
- global_config=asdict(self),
236
  embedding_func=self.embedding_func,
237
  )
238
  self.text_chunks = self.key_string_value_json_storage_cls(
239
  namespace="text_chunks",
240
- global_config=asdict(self),
241
  embedding_func=self.embedding_func,
242
  )
243
  self.chunk_entity_relation_graph = self.graph_storage_cls(
244
  namespace="chunk_entity_relation",
245
- global_config=asdict(self),
246
  embedding_func=self.embedding_func,
247
  )
248
  ####
@@ -251,72 +256,69 @@ class LightRAG:
251
 
252
  self.entities_vdb = self.vector_db_storage_cls(
253
  namespace="entities",
254
- global_config=asdict(self),
255
  embedding_func=self.embedding_func,
256
  meta_fields={"entity_name"},
257
  )
258
  self.relationships_vdb = self.vector_db_storage_cls(
259
  namespace="relationships",
260
- global_config=asdict(self),
261
  embedding_func=self.embedding_func,
262
  meta_fields={"src_id", "tgt_id"},
263
  )
264
  self.chunks_vdb = self.vector_db_storage_cls(
265
  namespace="chunks",
266
- global_config=asdict(self),
267
  embedding_func=self.embedding_func,
268
  )
269
 
 
 
 
 
 
 
 
 
 
 
270
  self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
271
  partial(
272
  self.llm_model_func,
273
- hashing_kv=self.llm_response_cache
274
- if self.llm_response_cache
275
- and hasattr(self.llm_response_cache, "global_config")
276
- else self.key_string_value_json_storage_cls(
277
- namespace="llm_response_cache",
278
- global_config=asdict(self),
279
- embedding_func=None,
280
- ),
281
  **self.llm_model_kwargs,
282
  )
283
  )
284
 
285
  # Initialize document status storage
286
- self.doc_status_storage_cls = self._get_storage_class()[self.doc_status_storage]
287
  self.doc_status = self.doc_status_storage_cls(
288
  namespace="doc_status",
289
- global_config=asdict(self),
290
  embedding_func=None,
291
  )
292
 
293
- def _get_storage_class(self) -> dict:
294
- return {
295
- # kv storage
296
- "JsonKVStorage": JsonKVStorage,
297
- "OracleKVStorage": OracleKVStorage,
298
- "MongoKVStorage": MongoKVStorage,
299
- "TiDBKVStorage": TiDBKVStorage,
300
- # vector storage
301
- "NanoVectorDBStorage": NanoVectorDBStorage,
302
- "OracleVectorDBStorage": OracleVectorDBStorage,
303
- "MilvusVectorDBStorge": MilvusVectorDBStorge,
304
- "ChromaVectorDBStorage": ChromaVectorDBStorage,
305
- "TiDBVectorDBStorage": TiDBVectorDBStorage,
306
- # graph storage
307
- "NetworkXStorage": NetworkXStorage,
308
- "Neo4JStorage": Neo4JStorage,
309
- "OracleGraphStorage": OracleGraphStorage,
310
- "AGEStorage": AGEStorage,
311
- "PGGraphStorage": PGGraphStorage,
312
- "PGKVStorage": PGKVStorage,
313
- "PGDocStatusStorage": PGDocStatusStorage,
314
- "PGVectorStorage": PGVectorStorage,
315
- "TiDBGraphStorage": TiDBGraphStorage,
316
- "GremlinStorage": GremlinStorage,
317
- # "ArangoDBStorage": ArangoDBStorage
318
- "JsonDocStatusStorage": JsonDocStatusStorage,
319
- }
320
 
321
  def insert(
322
  self, string_or_strings, split_by_character=None, split_by_character_only=False
@@ -538,6 +540,195 @@ class LightRAG:
538
  if update_storage:
539
  await self._insert_done()
540
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
541
  async def _insert_done(self):
542
  tasks = []
543
  for storage_inst in [
@@ -753,6 +944,114 @@ class LightRAG:
753
  await self._query_done()
754
  return response
755
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
756
  async def _query_done(self):
757
  tasks = []
758
  for storage_inst in [self.llm_response_cache]:
 
17
  kg_query,
18
  naive_query,
19
  mix_kg_vector_query,
20
+ extract_keywords_only,
21
+ kg_query_with_keywords,
22
  )
23
 
24
  from .utils import (
 
28
  convert_response_to_json,
29
  logger,
30
  set_logger,
31
+ statistic_data,
32
  )
33
  from .base import (
34
  BaseGraphStorage,
 
39
  DocStatus,
40
  )
41
 
 
 
 
 
 
 
 
42
  from .prompt import GRAPH_FIELD_SEP
43
 
44
+ STORAGES = {
45
+ "JsonKVStorage": ".storage",
46
+ "NanoVectorDBStorage": ".storage",
47
+ "NetworkXStorage": ".storage",
48
+ "JsonDocStatusStorage": ".storage",
49
+ "Neo4JStorage": ".kg.neo4j_impl",
50
+ "OracleKVStorage": ".kg.oracle_impl",
51
+ "OracleGraphStorage": ".kg.oracle_impl",
52
+ "OracleVectorDBStorage": ".kg.oracle_impl",
53
+ "MilvusVectorDBStorge": ".kg.milvus_impl",
54
+ "MongoKVStorage": ".kg.mongo_impl",
55
+ "ChromaVectorDBStorage": ".kg.chroma_impl",
56
+ "TiDBKVStorage": ".kg.tidb_impl",
57
+ "TiDBVectorDBStorage": ".kg.tidb_impl",
58
+ "TiDBGraphStorage": ".kg.tidb_impl",
59
+ "PGKVStorage": ".kg.postgres_impl",
60
+ "PGVectorStorage": ".kg.postgres_impl",
61
+ "AGEStorage": ".kg.age_impl",
62
+ "PGGraphStorage": ".kg.postgres_impl",
63
+ "GremlinStorage": ".kg.gremlin_impl",
64
+ "PGDocStatusStorage": ".kg.postgres_impl",
65
+ }
66
 
67
 
68
  def lazy_external_import(module_name: str, class_name: str):
 
78
  def import_class(*args, **kwargs):
79
  import importlib
80
 
 
81
  module = importlib.import_module(module_name, package=package)
 
 
82
  cls = getattr(module, class_name)
83
  return cls(*args, **kwargs)
84
 
85
  return import_class
86
 
87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
  def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
89
  """
90
  Ensure that there is always an event loop available.
 
188
  logger.setLevel(self.log_level)
189
 
190
  logger.info(f"Logger initialized for working directory: {self.working_dir}")
191
+ if not os.path.exists(self.working_dir):
192
+ logger.info(f"Creating working directory {self.working_dir}")
193
+ os.makedirs(self.working_dir)
194
 
195
+ # show config
196
+ global_config = asdict(self)
197
+ _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
198
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
199
 
200
+ # Init LLM
201
+ self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
202
+ self.embedding_func
203
+ )
204
 
205
+ # Initialize all storages
206
  self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
207
+ self._get_storage_class(self.kv_storage)
208
  )
209
+ self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class(
210
  self.vector_storage
211
+ )
212
+ self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class(
213
  self.graph_storage
214
+ )
215
 
216
+ self.key_string_value_json_storage_cls = partial(
217
+ self.key_string_value_json_storage_cls, global_config=global_config
218
+ )
219
 
220
+ self.vector_db_storage_cls = partial(
221
+ self.vector_db_storage_cls, global_config=global_config
222
+ )
223
+
224
+ self.graph_storage_cls = partial(
225
+ self.graph_storage_cls, global_config=global_config
226
+ )
227
+
228
+ self.json_doc_status_storage = self.key_string_value_json_storage_cls(
229
+ namespace="json_doc_status_storage",
230
  embedding_func=None,
231
  )
232
 
233
+ self.llm_response_cache = self.key_string_value_json_storage_cls(
234
+ namespace="llm_response_cache",
235
+ embedding_func=None,
236
  )
237
 
238
  ####
 
240
  ####
241
  self.full_docs = self.key_string_value_json_storage_cls(
242
  namespace="full_docs",
 
243
  embedding_func=self.embedding_func,
244
  )
245
  self.text_chunks = self.key_string_value_json_storage_cls(
246
  namespace="text_chunks",
 
247
  embedding_func=self.embedding_func,
248
  )
249
  self.chunk_entity_relation_graph = self.graph_storage_cls(
250
  namespace="chunk_entity_relation",
 
251
  embedding_func=self.embedding_func,
252
  )
253
  ####
 
256
 
257
  self.entities_vdb = self.vector_db_storage_cls(
258
  namespace="entities",
 
259
  embedding_func=self.embedding_func,
260
  meta_fields={"entity_name"},
261
  )
262
  self.relationships_vdb = self.vector_db_storage_cls(
263
  namespace="relationships",
 
264
  embedding_func=self.embedding_func,
265
  meta_fields={"src_id", "tgt_id"},
266
  )
267
  self.chunks_vdb = self.vector_db_storage_cls(
268
  namespace="chunks",
 
269
  embedding_func=self.embedding_func,
270
  )
271
 
272
+ if self.llm_response_cache and hasattr(
273
+ self.llm_response_cache, "global_config"
274
+ ):
275
+ hashing_kv = self.llm_response_cache
276
+ else:
277
+ hashing_kv = self.key_string_value_json_storage_cls(
278
+ namespace="llm_response_cache",
279
+ embedding_func=None,
280
+ )
281
+
282
  self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
283
  partial(
284
  self.llm_model_func,
285
+ hashing_kv=hashing_kv,
 
 
 
 
 
 
 
286
  **self.llm_model_kwargs,
287
  )
288
  )
289
 
290
  # Initialize document status storage
291
+ self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
292
  self.doc_status = self.doc_status_storage_cls(
293
  namespace="doc_status",
294
+ global_config=global_config,
295
  embedding_func=None,
296
  )
297
 
298
+ def _get_storage_class(self, storage_name: str) -> dict:
299
+ import_path = STORAGES[storage_name]
300
+ storage_class = lazy_external_import(import_path, storage_name)
301
+ return storage_class
302
+
303
+ def set_storage_client(self, db_client):
304
+ # Now only tested on Oracle Database
305
+ for storage in [
306
+ self.vector_db_storage_cls,
307
+ self.graph_storage_cls,
308
+ self.doc_status,
309
+ self.full_docs,
310
+ self.text_chunks,
311
+ self.llm_response_cache,
312
+ self.key_string_value_json_storage_cls,
313
+ self.chunks_vdb,
314
+ self.relationships_vdb,
315
+ self.entities_vdb,
316
+ self.graph_storage_cls,
317
+ self.chunk_entity_relation_graph,
318
+ self.llm_response_cache,
319
+ ]:
320
+ # set client
321
+ storage.db = db_client
 
 
 
322
 
323
  def insert(
324
  self, string_or_strings, split_by_character=None, split_by_character_only=False
 
540
  if update_storage:
541
  await self._insert_done()
542
 
543
+ async def apipeline_process_documents(self, string_or_strings):
544
+ """Input list remove duplicates, generate document IDs and initial pendding status, filter out already stored documents, store docs
545
+ Args:
546
+ string_or_strings: Single document string or list of document strings
547
+ """
548
+ if isinstance(string_or_strings, str):
549
+ string_or_strings = [string_or_strings]
550
+
551
+ # 1. Remove duplicate contents from the list
552
+ unique_contents = list(set(doc.strip() for doc in string_or_strings))
553
+
554
+ logger.info(
555
+ f"Received {len(string_or_strings)} docs, contains {len(unique_contents)} new unique documents"
556
+ )
557
+
558
+ # 2. Generate document IDs and initial status
559
+ new_docs = {
560
+ compute_mdhash_id(content, prefix="doc-"): {
561
+ "content": content,
562
+ "content_summary": self._get_content_summary(content),
563
+ "content_length": len(content),
564
+ "status": DocStatus.PENDING,
565
+ "created_at": datetime.now().isoformat(),
566
+ "updated_at": None,
567
+ }
568
+ for content in unique_contents
569
+ }
570
+
571
+ # 3. Filter out already processed documents
572
+ _not_stored_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
573
+ if len(_not_stored_doc_keys) < len(new_docs):
574
+ logger.info(
575
+ f"Skipping {len(new_docs)-len(_not_stored_doc_keys)} already existing documents"
576
+ )
577
+ new_docs = {k: v for k, v in new_docs.items() if k in _not_stored_doc_keys}
578
+
579
+ if not new_docs:
580
+ logger.info("All documents have been processed or are duplicates")
581
+ return None
582
+
583
+ # 4. Store original document
584
+ for doc_id, doc in new_docs.items():
585
+ await self.full_docs.upsert({doc_id: {"content": doc["content"]}})
586
+ await self.full_docs.change_status(doc_id, DocStatus.PENDING)
587
+ logger.info(f"Stored {len(new_docs)} new unique documents")
588
+
589
+ async def apipeline_process_chunks(self):
590
+ """Get pendding documents, split into chunks,insert chunks"""
591
+ # 1. get all pending and failed documents
592
+ _todo_doc_keys = []
593
+ _failed_doc = await self.full_docs.get_by_status_and_ids(
594
+ status=DocStatus.FAILED, ids=None
595
+ )
596
+ _pendding_doc = await self.full_docs.get_by_status_and_ids(
597
+ status=DocStatus.PENDING, ids=None
598
+ )
599
+ if _failed_doc:
600
+ _todo_doc_keys.extend([doc["id"] for doc in _failed_doc])
601
+ if _pendding_doc:
602
+ _todo_doc_keys.extend([doc["id"] for doc in _pendding_doc])
603
+ if not _todo_doc_keys:
604
+ logger.info("All documents have been processed or are duplicates")
605
+ return None
606
+ else:
607
+ logger.info(f"Filtered out {len(_todo_doc_keys)} not processed documents")
608
+
609
+ new_docs = {
610
+ doc["id"]: doc for doc in await self.full_docs.get_by_ids(_todo_doc_keys)
611
+ }
612
+
613
+ # 2. split docs into chunks, insert chunks, update doc status
614
+ chunk_cnt = 0
615
+ batch_size = self.addon_params.get("insert_batch_size", 10)
616
+ for i in range(0, len(new_docs), batch_size):
617
+ batch_docs = dict(list(new_docs.items())[i : i + batch_size])
618
+ for doc_id, doc in tqdm_async(
619
+ batch_docs.items(),
620
+ desc=f"Level 1 - Spliting doc in batch {i//batch_size + 1}",
621
+ ):
622
+ try:
623
+ # Generate chunks from document
624
+ chunks = {
625
+ compute_mdhash_id(dp["content"], prefix="chunk-"): {
626
+ **dp,
627
+ "full_doc_id": doc_id,
628
+ "status": DocStatus.PENDING,
629
+ }
630
+ for dp in chunking_by_token_size(
631
+ doc["content"],
632
+ overlap_token_size=self.chunk_overlap_token_size,
633
+ max_token_size=self.chunk_token_size,
634
+ tiktoken_model=self.tiktoken_model_name,
635
+ )
636
+ }
637
+ chunk_cnt += len(chunks)
638
+ await self.text_chunks.upsert(chunks)
639
+ await self.text_chunks.change_status(doc_id, DocStatus.PROCESSED)
640
+
641
+ try:
642
+ # Store chunks in vector database
643
+ await self.chunks_vdb.upsert(chunks)
644
+ # Update doc status
645
+ await self.full_docs.change_status(doc_id, DocStatus.PROCESSED)
646
+ except Exception as e:
647
+ # Mark as failed if any step fails
648
+ await self.full_docs.change_status(doc_id, DocStatus.FAILED)
649
+ raise e
650
+ except Exception as e:
651
+ import traceback
652
+
653
+ error_msg = f"Failed to process document {doc_id}: {str(e)}\n{traceback.format_exc()}"
654
+ logger.error(error_msg)
655
+ continue
656
+ logger.info(f"Stored {chunk_cnt} chunks from {len(new_docs)} documents")
657
+
658
+ async def apipeline_process_extract_graph(self):
659
+ """Get pendding or failed chunks, extract entities and relationships from each chunk"""
660
+ # 1. get all pending and failed chunks
661
+ _todo_chunk_keys = []
662
+ _failed_chunks = await self.text_chunks.get_by_status_and_ids(
663
+ status=DocStatus.FAILED, ids=None
664
+ )
665
+ _pending_chunks = await self.text_chunks.get_by_status_and_ids(
666
+ status=DocStatus.PENDING, ids=None
667
+ )
668
+ if _failed_chunks:
669
+ _todo_chunk_keys.extend([doc["id"] for doc in _failed_chunks])
670
+ if _pending_chunks:
671
+ _todo_chunk_keys.extend([doc["id"] for doc in _pendding_chunks])
672
+ if not _todo_chunk_keys:
673
+ logger.info("All chunks have been processed or are duplicates")
674
+ return None
675
+
676
+ # Process documents in batches
677
+ batch_size = self.addon_params.get("insert_batch_size", 10)
678
+
679
+ semaphore = asyncio.Semaphore(
680
+ batch_size
681
+ ) # Control the number of tasks that are processed simultaneously
682
+
683
+ async def process_chunk(chunk_id):
684
+ async with semaphore:
685
+ chunks = {
686
+ i["id"]: i for i in await self.text_chunks.get_by_ids([chunk_id])
687
+ }
688
+ # Extract and store entities and relationships
689
+ try:
690
+ maybe_new_kg = await extract_entities(
691
+ chunks,
692
+ knowledge_graph_inst=self.chunk_entity_relation_graph,
693
+ entity_vdb=self.entities_vdb,
694
+ relationships_vdb=self.relationships_vdb,
695
+ llm_response_cache=self.llm_response_cache,
696
+ global_config=asdict(self),
697
+ )
698
+ if maybe_new_kg is None:
699
+ logger.info("No entities or relationships extracted!")
700
+ # Update status to processed
701
+ await self.text_chunks.change_status(chunk_id, DocStatus.PROCESSED)
702
+ except Exception as e:
703
+ logger.error("Failed to extract entities and relationships")
704
+ # Mark as failed if any step fails
705
+ await self.text_chunks.change_status(chunk_id, DocStatus.FAILED)
706
+ raise e
707
+
708
+ with tqdm_async(
709
+ total=len(_todo_chunk_keys),
710
+ desc="\nLevel 1 - Processing chunks",
711
+ unit="chunk",
712
+ position=0,
713
+ ) as progress:
714
+ tasks = []
715
+ for chunk_id in _todo_chunk_keys:
716
+ task = asyncio.create_task(process_chunk(chunk_id))
717
+ tasks.append(task)
718
+
719
+ for future in asyncio.as_completed(tasks):
720
+ await future
721
+ progress.update(1)
722
+ progress.set_postfix(
723
+ {
724
+ "LLM call": statistic_data["llm_call"],
725
+ "LLM cache": statistic_data["llm_cache"],
726
+ }
727
+ )
728
+
729
+ # Ensure all indexes are updated after each document
730
+ await self._insert_done()
731
+
732
  async def _insert_done(self):
733
  tasks = []
734
  for storage_inst in [
 
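The three `apipeline_*` stages above form a retry loop around a per-record status field: records are stored as `DocStatus.PENDING`, flipped to `PROCESSED` when a stage succeeds, and set to `FAILED` when it raises, so the next pipeline run picks them up again. The enum itself is not part of this diff; a minimal sketch of the shape the code assumes:

```python
# Minimal sketch of the DocStatus shape assumed by the pipeline above;
# the actual enum ships with LightRAG and is not defined in this diff.
from enum import Enum

class DocStatus(str, Enum):
    PENDING = "pending"      # enqueued; not yet chunked or extracted
    PROCESSED = "processed"  # stage completed successfully
    FAILED = "failed"        # stage raised; retried on the next run
```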
944
  await self._query_done()
945
  return response
946
 
947
+ def query_with_separate_keyword_extraction(
948
+ self, query: str, prompt: str, param: QueryParam = QueryParam()
949
+ ):
950
+ """
951
+ 1. Extract keywords from the 'query' using extract_keywords_only() in operate.py.
952
+ 2. Then run the standard aquery() flow with the final prompt (formatted_question).
953
+ """
954
+
955
+ loop = always_get_an_event_loop()
956
+ return loop.run_until_complete(
957
+ self.aquery_with_separate_keyword_extraction(query, prompt, param)
958
+ )
959
+
960
+ async def aquery_with_separate_keyword_extraction(
961
+ self, query: str, prompt: str, param: QueryParam = QueryParam()
962
+ ):
963
+ """
964
+ 1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
965
+ 2. Then calls kg_query_with_keywords(...), naive_query(...), or mix_kg_vector_query(...) depending on param.mode, injecting the newly extracted keywords.
966
+ """
967
+
968
+ # ---------------------
969
+ # STEP 1: Keyword Extraction
970
+ # ---------------------
971
+ # extract_keywords_only(...) returns (hl_keywords, ll_keywords) for the raw query.
972
+ hl_keywords, ll_keywords = await extract_keywords_only(
973
+ text=query,
974
+ param=param,
975
+ global_config=asdict(self),
976
+ hashing_kv=self.llm_response_cache
977
+ or self.key_string_value_json_storage_cls(
978
+ namespace="llm_response_cache",
979
+ global_config=asdict(self),
980
+ embedding_func=None,
981
+ ),
982
+ )
983
+
984
+ param.hl_keywords = hl_keywords
985
+ param.ll_keywords = ll_keywords
986
+
987
+ # ---------------------
988
+ # STEP 2: Final Query Logic
989
+ # ---------------------
990
+
991
+ # Create a new string with the prompt and the keywords
992
+ ll_keywords_str = ", ".join(ll_keywords)
993
+ hl_keywords_str = ", ".join(hl_keywords)
994
+ formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}"
995
+
996
+ if param.mode in ["local", "global", "hybrid"]:
997
+ response = await kg_query_with_keywords(
998
+ formatted_question,
999
+ self.chunk_entity_relation_graph,
1000
+ self.entities_vdb,
1001
+ self.relationships_vdb,
1002
+ self.text_chunks,
1003
+ param,
1004
+ asdict(self),
1005
+ hashing_kv=self.llm_response_cache
1006
+ if self.llm_response_cache
1007
+ and hasattr(self.llm_response_cache, "global_config")
1008
+ else self.key_string_value_json_storage_cls(
1009
+ namespace="llm_response_cache",
1010
+ global_config=asdict(self),
1011
+ embedding_func=None,
1012
+ ),
1013
+ )
1014
+ elif param.mode == "naive":
1015
+ response = await naive_query(
1016
+ formatted_question,
1017
+ self.chunks_vdb,
1018
+ self.text_chunks,
1019
+ param,
1020
+ asdict(self),
1021
+ hashing_kv=self.llm_response_cache
1022
+ if self.llm_response_cache
1023
+ and hasattr(self.llm_response_cache, "global_config")
1024
+ else self.key_string_value_json_storage_cls(
1025
+ namespace="llm_response_cache",
1026
+ global_config=asdict(self),
1027
+ embedding_func=None,
1028
+ ),
1029
+ )
1030
+ elif param.mode == "mix":
1031
+ response = await mix_kg_vector_query(
1032
+ formatted_question,
1033
+ self.chunk_entity_relation_graph,
1034
+ self.entities_vdb,
1035
+ self.relationships_vdb,
1036
+ self.chunks_vdb,
1037
+ self.text_chunks,
1038
+ param,
1039
+ asdict(self),
1040
+ hashing_kv=self.llm_response_cache
1041
+ if self.llm_response_cache
1042
+ and hasattr(self.llm_response_cache, "global_config")
1043
+ else self.key_string_value_json_storage_cls(
1044
+ namespace="llm_response_cache",
1045
+ global_config=asdict(self),
1046
+ embedding_func=None,
1047
+ ),
1048
+ )
1049
+ else:
1050
+ raise ValueError(f"Unknown mode {param.mode}")
1051
+
1052
+ await self._query_done()
1053
+ return response
1054
+
1055
  async def _query_done(self):
1056
  tasks = []
1057
  for storage_inst in [self.llm_response_cache]:
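Taken together, the changes to `lightrag.py` split insertion into independent, restartable stages and decouple keyword extraction from the user's formatting prompt at query time. A minimal sketch of driving both paths, assuming `rag` is an already-initialized `LightRAG` instance whose documents were previously enqueued as `PENDING`:

```python
# Minimal sketch, not part of this commit: `rag` construction is omitted.
import asyncio
from lightrag import QueryParam

async def run(rag):
    # Stage 1: chunk all pending/failed documents and store the chunks
    await rag.apipeline_process_chunks()
    # Stage 2: extract entities/relationships from pending/failed chunks
    await rag.apipeline_process_extract_graph()
    # Query with keyword extraction decoupled from the formatting prompt
    return await rag.aquery_with_separate_keyword_extraction(
        query="What causes ocean tides?",
        prompt="Answer in two short paragraphs.",
        param=QueryParam(mode="hybrid"),
    )

answer = asyncio.run(run(rag))
```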
lightrag/operate.py CHANGED
@@ -20,6 +20,7 @@ from .utils import (
20
  handle_cache,
21
  save_to_cache,
22
  CacheData,
 
23
  )
24
  from .base import (
25
  BaseGraphStorage,
@@ -96,6 +97,10 @@ async def _handle_entity_relation_summary(
96
  description: str,
97
  global_config: dict,
98
  ) -> str:
 
 
 
 
99
  use_llm_func: callable = global_config["llm_model_func"]
100
  llm_max_tokens = global_config["llm_model_max_token_size"]
101
  tiktoken_model_name = global_config["tiktoken_model_name"]
@@ -176,6 +181,7 @@ async def _merge_nodes_then_upsert(
176
  knowledge_graph_inst: BaseGraphStorage,
177
  global_config: dict,
178
  ):
 
179
  already_entity_types = []
180
  already_source_ids = []
181
  already_description = []
@@ -356,7 +362,7 @@ async def extract_entities(
356
  llm_response_cache.global_config = new_config
357
  need_to_restore = True
358
  if history_messages:
359
- history = json.dumps(history_messages)
360
  _prompt = history + "\n" + input_text
361
  else:
362
  _prompt = input_text
@@ -368,8 +374,10 @@ async def extract_entities(
368
  if need_to_restore:
369
  llm_response_cache.global_config = global_config
370
  if cached_return:
 
 
371
  return cached_return
372
-
373
  if history_messages:
374
  res: str = await use_llm_func(
375
  input_text, history_messages=history_messages
@@ -388,6 +396,11 @@ async def extract_entities(
388
  return await use_llm_func(input_text)
389
 
390
  async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
 
 
 
 
 
391
  nonlocal already_processed, already_entities, already_relations
392
  chunk_key = chunk_key_dp[0]
393
  chunk_dp = chunk_key_dp[1]
@@ -451,10 +464,8 @@ async def extract_entities(
451
  now_ticks = PROMPTS["process_tickers"][
452
  already_processed % len(PROMPTS["process_tickers"])
453
  ]
454
- print(
455
  f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
456
- end="",
457
- flush=True,
458
  )
459
  return dict(maybe_nodes), dict(maybe_edges)
460
 
@@ -462,8 +473,10 @@ async def extract_entities(
462
  for result in tqdm_async(
463
  asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]),
464
  total=len(ordered_chunks),
465
- desc="Extracting entities from chunks",
466
  unit="chunk",
 
 
467
  ):
468
  results.append(await result)
469
 
@@ -474,7 +487,7 @@ async def extract_entities(
474
  maybe_nodes[k].extend(v)
475
  for k, v in m_edges.items():
476
  maybe_edges[tuple(sorted(k))].extend(v)
477
- logger.info("Inserting entities into storage...")
478
  all_entities_data = []
479
  for result in tqdm_async(
480
  asyncio.as_completed(
@@ -484,12 +497,14 @@ async def extract_entities(
484
  ]
485
  ),
486
  total=len(maybe_nodes),
487
- desc="Inserting entities",
488
  unit="entity",
 
 
489
  ):
490
  all_entities_data.append(await result)
491
 
492
- logger.info("Inserting relationships into storage...")
493
  all_relationships_data = []
494
  for result in tqdm_async(
495
  asyncio.as_completed(
@@ -501,8 +516,10 @@ async def extract_entities(
501
  ]
502
  ),
503
  total=len(maybe_edges),
504
- desc="Inserting relationships",
505
  unit="relationship",
 
 
506
  ):
507
  all_relationships_data.append(await result)
508
 
@@ -681,6 +698,219 @@ async def kg_query(
681
  return response
682
 
683
 
684
  async def _build_query_context(
685
  query: list,
686
  knowledge_graph_inst: BaseGraphStorage,
 
20
  handle_cache,
21
  save_to_cache,
22
  CacheData,
23
+ statistic_data,
24
  )
25
  from .base import (
26
  BaseGraphStorage,
 
97
  description: str,
98
  global_config: dict,
99
  ) -> str:
100
+ """Handle entity relation summary
101
+ For each entity or relation, the input is the already existing description combined with the new description.
102
+ If it is too long, use the LLM to summarize it.
103
+ """
104
  use_llm_func: callable = global_config["llm_model_func"]
105
  llm_max_tokens = global_config["llm_model_max_token_size"]
106
  tiktoken_model_name = global_config["tiktoken_model_name"]
 
181
  knowledge_graph_inst: BaseGraphStorage,
182
  global_config: dict,
183
  ):
184
+ """Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert."""
185
  already_entity_types = []
186
  already_source_ids = []
187
  already_description = []
 
362
  llm_response_cache.global_config = new_config
363
  need_to_restore = True
364
  if history_messages:
365
+ history = json.dumps(history_messages, ensure_ascii=False)
366
  _prompt = history + "\n" + input_text
367
  else:
368
  _prompt = input_text
 
374
  if need_to_restore:
375
  llm_response_cache.global_config = global_config
376
  if cached_return:
377
+ logger.debug(f"Found cache for {arg_hash}")
378
+ statistic_data["llm_cache"] += 1
379
  return cached_return
380
+ statistic_data["llm_call"] += 1
381
  if history_messages:
382
  res: str = await use_llm_func(
383
  input_text, history_messages=history_messages
 
396
  return await use_llm_func(input_text)
397
 
398
  async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
399
+ """ "Prpocess a single chunk
400
+ Args:
401
+ chunk_key_dp (tuple[str, TextChunkSchema]):
402
+ ("chunck-xxxxxx", {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int})
403
+ """
404
  nonlocal already_processed, already_entities, already_relations
405
  chunk_key = chunk_key_dp[0]
406
  chunk_dp = chunk_key_dp[1]
 
464
  now_ticks = PROMPTS["process_tickers"][
465
  already_processed % len(PROMPTS["process_tickers"])
466
  ]
467
+ logger.debug(
468
  f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
 
 
469
  )
470
  return dict(maybe_nodes), dict(maybe_edges)
471
 
 
473
  for result in tqdm_async(
474
  asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]),
475
  total=len(ordered_chunks),
476
+ desc="Level 2 - Extracting entities and relationships",
477
  unit="chunk",
478
+ position=1,
479
+ leave=False,
480
  ):
481
  results.append(await result)
482
 
 
487
  maybe_nodes[k].extend(v)
488
  for k, v in m_edges.items():
489
  maybe_edges[tuple(sorted(k))].extend(v)
490
+ logger.debug("Inserting entities into storage...")
491
  all_entities_data = []
492
  for result in tqdm_async(
493
  asyncio.as_completed(
 
497
  ]
498
  ),
499
  total=len(maybe_nodes),
500
+ desc="Level 3 - Inserting entities",
501
  unit="entity",
502
+ position=2,
503
+ leave=False,
504
  ):
505
  all_entities_data.append(await result)
506
 
507
+ logger.debug("Inserting relationships into storage...")
508
  all_relationships_data = []
509
  for result in tqdm_async(
510
  asyncio.as_completed(
 
516
  ]
517
  ),
518
  total=len(maybe_edges),
519
+ desc="Level 3 - Inserting relationships",
520
  unit="relationship",
521
+ position=3,
522
+ leave=False,
523
  ):
524
  all_relationships_data.append(await result)
525
 
 
698
  return response
699
 
700
 
701
+ async def kg_query_with_keywords(
702
+ query: str,
703
+ knowledge_graph_inst: BaseGraphStorage,
704
+ entities_vdb: BaseVectorStorage,
705
+ relationships_vdb: BaseVectorStorage,
706
+ text_chunks_db: BaseKVStorage[TextChunkSchema],
707
+ query_param: QueryParam,
708
+ global_config: dict,
709
+ hashing_kv: BaseKVStorage = None,
710
+ ) -> str:
711
+ """
712
+ Refactored kg_query that does NOT extract keywords by itself.
713
+ It expects hl_keywords and ll_keywords to be set in query_param, or defaults to empty.
714
+ Then it uses those to build context and produce a final LLM response.
715
+ """
716
+
717
+ # ---------------------------
718
+ # 0) Handle potential cache
719
+ # ---------------------------
720
+ use_model_func = global_config["llm_model_func"]
721
+ args_hash = compute_args_hash(query_param.mode, query)
722
+ cached_response, quantized, min_val, max_val = await handle_cache(
723
+ hashing_kv, args_hash, query, query_param.mode
724
+ )
725
+ if cached_response is not None:
726
+ return cached_response
727
+
728
+ # ---------------------------
729
+ # 1) RETRIEVE KEYWORDS FROM query_param
730
+ # ---------------------------
731
+
732
+ # If these fields don't exist, default to empty lists.
733
+ hl_keywords = getattr(query_param, "hl_keywords", []) or []
734
+ ll_keywords = getattr(query_param, "ll_keywords", []) or []
735
+
736
+ # If neither level has any keywords, fail early.
737
+ if not hl_keywords and not ll_keywords:
738
+ logger.warning(
739
+ "No keywords found in query_param. Could default to global mode or fail."
740
+ )
741
+ return PROMPTS["fail_response"]
742
+ if not ll_keywords and query_param.mode in ["local", "hybrid"]:
743
+ logger.warning("low_level_keywords is empty, switching to global mode.")
744
+ query_param.mode = "global"
745
+ if not hl_keywords and query_param.mode in ["global", "hybrid"]:
746
+ logger.warning("high_level_keywords is empty, switching to local mode.")
747
+ query_param.mode = "local"
748
+
749
+ # Flatten low-level and high-level keywords if needed
750
+ ll_keywords_flat = (
751
+ [item for sublist in ll_keywords for item in sublist]
752
+ if any(isinstance(i, list) for i in ll_keywords)
753
+ else ll_keywords
754
+ )
755
+ hl_keywords_flat = (
756
+ [item for sublist in hl_keywords for item in sublist]
757
+ if any(isinstance(i, list) for i in hl_keywords)
758
+ else hl_keywords
759
+ )
760
+
761
+ # Join the flattened lists
762
+ ll_keywords_str = ", ".join(ll_keywords_flat) if ll_keywords_flat else ""
763
+ hl_keywords_str = ", ".join(hl_keywords_flat) if hl_keywords_flat else ""
764
+
765
+ keywords = [ll_keywords_str, hl_keywords_str]
766
+
767
+ logger.info("Using %s mode for query processing", query_param.mode)
768
+
769
+ # ---------------------------
770
+ # 2) BUILD CONTEXT
771
+ # ---------------------------
772
+ context = await _build_query_context(
773
+ keywords,
774
+ knowledge_graph_inst,
775
+ entities_vdb,
776
+ relationships_vdb,
777
+ text_chunks_db,
778
+ query_param,
779
+ )
780
+ if not context:
781
+ return PROMPTS["fail_response"]
782
+
783
+ # If only context is needed, return it
784
+ if query_param.only_need_context:
785
+ return context
786
+
787
+ # ---------------------------
788
+ # 3) BUILD THE SYSTEM PROMPT + CALL LLM
789
+ # ---------------------------
790
+ sys_prompt_temp = PROMPTS["rag_response"]
791
+ sys_prompt = sys_prompt_temp.format(
792
+ context_data=context, response_type=query_param.response_type
793
+ )
794
+
795
+ if query_param.only_need_prompt:
796
+ return sys_prompt
797
+
798
+ # Now call the LLM with the final system prompt
799
+ response = await use_model_func(
800
+ query,
801
+ system_prompt=sys_prompt,
802
+ stream=query_param.stream,
803
+ )
804
+
805
+ # Clean up the response
806
+ if isinstance(response, str) and len(response) > len(sys_prompt):
807
+ response = (
808
+ response.replace(sys_prompt, "")
809
+ .replace("user", "")
810
+ .replace("model", "")
811
+ .replace(query, "")
812
+ .replace("<system>", "")
813
+ .replace("</system>", "")
814
+ .strip()
815
+ )
816
+
817
+ # ---------------------------
818
+ # 4) SAVE TO CACHE
819
+ # ---------------------------
820
+ await save_to_cache(
821
+ hashing_kv,
822
+ CacheData(
823
+ args_hash=args_hash,
824
+ content=response,
825
+ prompt=query,
826
+ quantized=quantized,
827
+ min_val=min_val,
828
+ max_val=max_val,
829
+ mode=query_param.mode,
830
+ ),
831
+ )
832
+ return response
833
+
834
+
835
+ async def extract_keywords_only(
836
+ text: str,
837
+ param: QueryParam,
838
+ global_config: dict,
839
+ hashing_kv: BaseKVStorage = None,
840
+ ) -> tuple[list[str], list[str]]:
841
+ """
842
+ Extract high-level and low-level keywords from the given 'text' using the LLM.
843
+ This method does NOT build the final RAG context or provide a final answer.
844
+ It ONLY extracts keywords (hl_keywords, ll_keywords).
845
+ """
846
+
847
+ # 1. Handle cache if needed
848
+ args_hash = compute_args_hash(param.mode, text)
849
+ cached_response, quantized, min_val, max_val = await handle_cache(
850
+ hashing_kv, args_hash, text, param.mode
851
+ )
852
+ if cached_response is not None:
853
+ # Parse the cached_response if it's JSON containing keywords,
854
+ # and simply return (hl_keywords, ll_keywords) from the cache.
855
+ # The cached response is assumed to use the same JSON structure:
856
+ match = re.search(r"\{.*\}", cached_response, re.DOTALL)
857
+ if match:
858
+ keywords_data = json.loads(match.group(0))
859
+ hl_keywords = keywords_data.get("high_level_keywords", [])
860
+ ll_keywords = keywords_data.get("low_level_keywords", [])
861
+ return hl_keywords, ll_keywords
862
+ return [], []
863
+
864
+ # 2. Build the examples
865
+ example_number = global_config["addon_params"].get("example_number", None)
866
+ if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
867
+ examples = "\n".join(
868
+ PROMPTS["keywords_extraction_examples"][: int(example_number)]
869
+ )
870
+ else:
871
+ examples = "\n".join(PROMPTS["keywords_extraction_examples"])
872
+ language = global_config["addon_params"].get(
873
+ "language", PROMPTS["DEFAULT_LANGUAGE"]
874
+ )
875
+
876
+ # 3. Build the keyword-extraction prompt
877
+ kw_prompt_temp = PROMPTS["keywords_extraction"]
878
+ kw_prompt = kw_prompt_temp.format(query=text, examples=examples, language=language)
879
+
880
+ # 4. Call the LLM for keyword extraction
881
+ use_model_func = global_config["llm_model_func"]
882
+ result = await use_model_func(kw_prompt, keyword_extraction=True)
883
+
884
+ # 5. Parse out JSON from the LLM response
885
+ match = re.search(r"\{.*\}", result, re.DOTALL)
886
+ if not match:
887
+ logger.error("No JSON-like structure found in the result.")
888
+ return [], []
889
+ try:
890
+ keywords_data = json.loads(match.group(0))
891
+ except json.JSONDecodeError as e:
892
+ logger.error(f"JSON parsing error: {e}")
893
+ return [], []
894
+
895
+ hl_keywords = keywords_data.get("high_level_keywords", [])
896
+ ll_keywords = keywords_data.get("low_level_keywords", [])
897
+
898
+ # 6. Cache the result if needed
899
+ await save_to_cache(
900
+ hashing_kv,
901
+ CacheData(
902
+ args_hash=args_hash,
903
+ content=result,
904
+ prompt=text,
905
+ quantized=quantized,
906
+ min_val=min_val,
907
+ max_val=max_val,
908
+ mode=param.mode,
909
+ ),
910
+ )
911
+ return hl_keywords, ll_keywords
912
+
913
+
914
  async def _build_query_context(
915
  query: list,
916
  knowledge_graph_inst: BaseGraphStorage,
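Because `extract_keywords_only` and `kg_query_with_keywords` are now separate functions, the extraction step can be run, inspected, or cached on its own before any answer is generated. A minimal sketch of that two-step flow, using only the signatures shown in the diff above and assuming an initialized `rag` instance:

```python
# Minimal sketch, not part of this commit: mirrors what
# aquery_with_separate_keyword_extraction() does internally.
from dataclasses import asdict
from lightrag import QueryParam
from lightrag.operate import extract_keywords_only, kg_query_with_keywords

async def two_step_query(rag, query: str) -> str:
    param = QueryParam(mode="hybrid")
    # Step 1: extract HL/LL keywords from the bare query only
    hl, ll = await extract_keywords_only(
        text=query, param=param, global_config=asdict(rag)
    )
    param.hl_keywords, param.ll_keywords = hl, ll
    # Step 2: build context from the pre-extracted keywords and answer
    return await kg_query_with_keywords(
        query,
        rag.chunk_entity_relation_graph,
        rag.entities_vdb,
        rag.relationships_vdb,
        rag.text_chunks,
        param,
        asdict(rag),
        hashing_kv=rag.llm_response_cache,
    )
```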
lightrag/utils.py CHANGED
@@ -30,13 +30,18 @@ class UnlimitedSemaphore:
30
 
31
  ENCODER = None
32
 
 
 
33
  logger = logging.getLogger("lightrag")
34
 
 
 
 
35
 
36
  def set_logger(log_file: str):
37
  logger.setLevel(logging.DEBUG)
38
 
39
- file_handler = logging.FileHandler(log_file)
40
  file_handler.setLevel(logging.DEBUG)
41
 
42
  formatter = logging.Formatter(
@@ -453,7 +458,8 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
453
  return None, None, None, None
454
 
455
  # For naive mode, only use simple cache matching
456
- if mode == "naive":
 
457
  if exists_func(hashing_kv, "get_by_mode_and_id"):
458
  mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
459
  else:
@@ -473,7 +479,9 @@ async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
473
  quantized = min_val = max_val = None
474
  if is_embedding_cache_enabled:
475
  # Use embedding cache
476
- embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
 
 
477
  llm_model_func = hashing_kv.global_config.get("llm_model_func")
478
 
479
  current_embedding = await embedding_model_func([prompt])
 
30
 
31
  ENCODER = None
32
 
33
+ statistic_data = {"llm_call": 0, "llm_cache": 0, "embed_call": 0}
34
+
35
  logger = logging.getLogger("lightrag")
36
 
37
+ # Set httpx logging level to WARNING
38
+ logging.getLogger("httpx").setLevel(logging.WARNING)
39
+
40
 
41
  def set_logger(log_file: str):
42
  logger.setLevel(logging.DEBUG)
43
 
44
+ file_handler = logging.FileHandler(log_file, encoding="utf-8")
45
  file_handler.setLevel(logging.DEBUG)
46
 
47
  formatter = logging.Formatter(
 
458
  return None, None, None, None
459
 
460
  # For naive mode, only use simple cache matching
461
+ # if mode == "naive":
462
+ if mode == "default":
463
  if exists_func(hashing_kv, "get_by_mode_and_id"):
464
  mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
465
  else:
 
479
  quantized = min_val = max_val = None
480
  if is_embedding_cache_enabled:
481
  # Use embedding cache
482
+ embedding_model_func = hashing_kv.global_config[
483
+ "embedding_func"
484
+ ].func # ["func"]
485
  llm_model_func = hashing_kv.global_config.get("llm_model_func")
486
 
487
  current_embedding = await embedding_model_func([prompt])
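The new module-level `statistic_data` dict is shared mutable state that the extraction pass in `operate.py` increments (`llm_call` on a fresh LLM request, `llm_cache` on a cache hit), so callers can read it after a run to gauge cache effectiveness. A small illustrative helper:

```python
# Minimal sketch: reads the shared counters defined in lightrag/utils.py.
from lightrag.utils import statistic_data

def report_llm_usage() -> None:
    calls = statistic_data["llm_call"]
    hits = statistic_data["llm_cache"]
    total = calls + hits
    # The hit rate is only meaningful once at least one lookup has happened.
    rate = hits / total if total else 0.0
    print(f"LLM calls: {calls}, cache hits: {hits} ({rate:.1%} hit rate)")
```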