YanSte commited on
Commit
ead39ba
·
1 Parent(s): e8119d6

cleaning up messages and removing the project file that is no longer needed

Browse files
README.md CHANGED
@@ -428,9 +428,9 @@ And using a routine to process news documents.
428
 
429
  ```python
430
  rag = LightRAG(..)
431
- await rag.apipeline_enqueue_documents(string_or_strings)
432
  # Your routine in loop
433
- await rag.apipeline_process_enqueue_documents(string_or_strings)
434
  ```
435
 
436
  ### Separate Keyword Extraction
 
428
 
429
  ```python
430
  rag = LightRAG(..)
431
+ await rag.apipeline_enqueue_documents(input)
432
  # Your routine in loop
433
+ await rag.apipeline_process_enqueue_documents(input)
434
  ```
435
 
436
  ### Separate Keyword Extraction
examples/lightrag_oracle_demo.py CHANGED
@@ -113,7 +113,24 @@ async def main():
113
  )
114
 
115
  # Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
116
- rag.set_storage_client(db_client=oracle_db)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
  # Extract and Insert into LightRAG storage
119
  with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
 
113
  )
114
 
115
  # Set the KV/vector/graph storage's `db` property, so all operations will use the same connection pool
116
+
117
+ for storage in [
118
+ rag.vector_db_storage_cls,
119
+ rag.graph_storage_cls,
120
+ rag.doc_status,
121
+ rag.full_docs,
122
+ rag.text_chunks,
123
+ rag.llm_response_cache,
124
+ rag.key_string_value_json_storage_cls,
125
+ rag.chunks_vdb,
126
+ rag.relationships_vdb,
127
+ rag.entities_vdb,
128
+ rag.graph_storage_cls,
129
+ rag.chunk_entity_relation_graph,
130
+ rag.llm_response_cache,
131
+ ]:
132
+ # set client
133
+ storage.db = oracle_db
134
 
135
  # Extract and Insert into LightRAG storage
136
  with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
external_bindings/OpenWebuiTool/openwebui_tool.py DELETED
@@ -1,358 +0,0 @@
1
- """
2
- OpenWebui Lightrag Integration Tool
3
- ==================================
4
-
5
- This tool enables the integration and use of Lightrag within the OpenWebui environment,
6
- providing a seamless interface for RAG (Retrieval-Augmented Generation) operations.
7
-
8
- Author: ParisNeo ([email protected])
9
- Social:
10
- - Twitter: @ParisNeo_AI
11
- - Reddit: r/lollms
12
- - Instagram: https://www.instagram.com/parisneo_ai/
13
-
14
- License: Apache 2.0
15
- Copyright (c) 2024-2025 ParisNeo
16
-
17
- This tool is part of the LoLLMs project (Lord of Large Language and Multimodal Systems).
18
- For more information, visit: https://github.com/ParisNeo/lollms
19
-
20
- Requirements:
21
- - Python 3.8+
22
- - OpenWebui
23
- - Lightrag
24
- """
25
-
26
- # Tool version
27
- __version__ = "1.0.0"
28
- __author__ = "ParisNeo"
29
- __author_email__ = "[email protected]"
30
- __description__ = "Lightrag integration for OpenWebui"
31
-
32
-
33
- import requests
34
- import json
35
- from pydantic import BaseModel, Field
36
- from typing import Callable, Any, Literal, Union, List, Tuple
37
-
38
-
39
- class StatusEventEmitter:
40
- def __init__(self, event_emitter: Callable[[dict], Any] = None):
41
- self.event_emitter = event_emitter
42
-
43
- async def emit(self, description="Unknown State", status="in_progress", done=False):
44
- if self.event_emitter:
45
- await self.event_emitter(
46
- {
47
- "type": "status",
48
- "data": {
49
- "status": status,
50
- "description": description,
51
- "done": done,
52
- },
53
- }
54
- )
55
-
56
-
57
- class MessageEventEmitter:
58
- def __init__(self, event_emitter: Callable[[dict], Any] = None):
59
- self.event_emitter = event_emitter
60
-
61
- async def emit(self, content="Some message"):
62
- if self.event_emitter:
63
- await self.event_emitter(
64
- {
65
- "type": "message",
66
- "data": {
67
- "content": content,
68
- },
69
- }
70
- )
71
-
72
-
73
- class Tools:
74
- class Valves(BaseModel):
75
- LIGHTRAG_SERVER_URL: str = Field(
76
- default="http://localhost:9621/query",
77
- description="The base URL for the LightRag server",
78
- )
79
- MODE: Literal["naive", "local", "global", "hybrid"] = Field(
80
- default="hybrid",
81
- description="The mode to use for the LightRag query. Options: naive, local, global, hybrid",
82
- )
83
- ONLY_NEED_CONTEXT: bool = Field(
84
- default=False,
85
- description="If True, only the context is needed from the LightRag response",
86
- )
87
- DEBUG_MODE: bool = Field(
88
- default=False,
89
- description="If True, debugging information will be emitted",
90
- )
91
- KEY: str = Field(
92
- default="",
93
- description="Optional Bearer Key for authentication",
94
- )
95
- MAX_ENTITIES: int = Field(
96
- default=5,
97
- description="Maximum number of entities to keep",
98
- )
99
- MAX_RELATIONSHIPS: int = Field(
100
- default=5,
101
- description="Maximum number of relationships to keep",
102
- )
103
- MAX_SOURCES: int = Field(
104
- default=3,
105
- description="Maximum number of sources to keep",
106
- )
107
-
108
- def __init__(self):
109
- self.valves = self.Valves()
110
- self.headers = {
111
- "Content-Type": "application/json",
112
- "User-Agent": "LightRag-Tool/1.0",
113
- }
114
-
115
- async def query_lightrag(
116
- self,
117
- query: str,
118
- __event_emitter__: Callable[[dict], Any] = None,
119
- ) -> str:
120
- """
121
- Query the LightRag server and retrieve information.
122
- This function must be called before answering the user question
123
- :params query: The query string to send to the LightRag server.
124
- :return: The response from the LightRag server in Markdown format or raw response.
125
- """
126
- self.status_emitter = StatusEventEmitter(__event_emitter__)
127
- self.message_emitter = MessageEventEmitter(__event_emitter__)
128
-
129
- lightrag_url = self.valves.LIGHTRAG_SERVER_URL
130
- payload = {
131
- "query": query,
132
- "mode": str(self.valves.MODE),
133
- "stream": False,
134
- "only_need_context": self.valves.ONLY_NEED_CONTEXT,
135
- }
136
- await self.status_emitter.emit("Initializing Lightrag query..")
137
-
138
- if self.valves.DEBUG_MODE:
139
- await self.message_emitter.emit(
140
- "### Debug Mode Active\n\nDebugging information will be displayed.\n"
141
- )
142
- await self.message_emitter.emit(
143
- "#### Payload Sent to LightRag Server\n```json\n"
144
- + json.dumps(payload, indent=4)
145
- + "\n```\n"
146
- )
147
-
148
- # Add Bearer Key to headers if provided
149
- if self.valves.KEY:
150
- self.headers["Authorization"] = f"Bearer {self.valves.KEY}"
151
-
152
- try:
153
- await self.status_emitter.emit("Sending request to LightRag server")
154
-
155
- response = requests.post(
156
- lightrag_url, json=payload, headers=self.headers, timeout=120
157
- )
158
- response.raise_for_status()
159
- data = response.json()
160
- await self.status_emitter.emit(
161
- status="complete",
162
- description="LightRag query Succeeded",
163
- done=True,
164
- )
165
-
166
- # Return parsed Markdown if ONLY_NEED_CONTEXT is True, otherwise return raw response
167
- if self.valves.ONLY_NEED_CONTEXT:
168
- try:
169
- if self.valves.DEBUG_MODE:
170
- await self.message_emitter.emit(
171
- "#### LightRag Server Response\n```json\n"
172
- + data["response"]
173
- + "\n```\n"
174
- )
175
- except Exception as ex:
176
- if self.valves.DEBUG_MODE:
177
- await self.message_emitter.emit(
178
- "#### Exception\n" + str(ex) + "\n"
179
- )
180
- return f"Exception: {ex}"
181
- return data["response"]
182
- else:
183
- if self.valves.DEBUG_MODE:
184
- await self.message_emitter.emit(
185
- "#### LightRag Server Response\n```json\n"
186
- + data["response"]
187
- + "\n```\n"
188
- )
189
- await self.status_emitter.emit("Lightrag query success")
190
- return data["response"]
191
-
192
- except requests.exceptions.RequestException as e:
193
- await self.status_emitter.emit(
194
- status="error",
195
- description=f"Error during LightRag query: {str(e)}",
196
- done=True,
197
- )
198
- return json.dumps({"error": str(e)})
199
-
200
- def extract_code_blocks(
201
- self, text: str, return_remaining_text: bool = False
202
- ) -> Union[List[dict], Tuple[List[dict], str]]:
203
- """
204
- This function extracts code blocks from a given text and optionally returns the text without code blocks.
205
-
206
- Parameters:
207
- text (str): The text from which to extract code blocks. Code blocks are identified by triple backticks (```).
208
- return_remaining_text (bool): If True, also returns the text with code blocks removed.
209
-
210
- Returns:
211
- Union[List[dict], Tuple[List[dict], str]]:
212
- - If return_remaining_text is False: Returns only the list of code block dictionaries
213
- - If return_remaining_text is True: Returns a tuple containing:
214
- * List of code block dictionaries
215
- * String containing the text with all code blocks removed
216
-
217
- Each code block dictionary contains:
218
- - 'index' (int): The index of the code block in the text
219
- - 'file_name' (str): The name of the file extracted from the preceding line, if available
220
- - 'content' (str): The content of the code block
221
- - 'type' (str): The type of the code block
222
- - 'is_complete' (bool): True if the block has a closing tag, False otherwise
223
- """
224
- remaining = text
225
- bloc_index = 0
226
- first_index = 0
227
- indices = []
228
- text_without_blocks = text
229
-
230
- # Find all code block delimiters
231
- while len(remaining) > 0:
232
- try:
233
- index = remaining.index("```")
234
- indices.append(index + first_index)
235
- remaining = remaining[index + 3 :]
236
- first_index += index + 3
237
- bloc_index += 1
238
- except Exception:
239
- if bloc_index % 2 == 1:
240
- index = len(remaining)
241
- indices.append(index)
242
- remaining = ""
243
-
244
- code_blocks = []
245
- is_start = True
246
-
247
- # Process code blocks and build text without blocks if requested
248
- if return_remaining_text:
249
- text_parts = []
250
- last_end = 0
251
-
252
- for index, code_delimiter_position in enumerate(indices):
253
- if is_start:
254
- block_infos = {
255
- "index": len(code_blocks),
256
- "file_name": "",
257
- "section": "",
258
- "content": "",
259
- "type": "",
260
- "is_complete": False,
261
- }
262
-
263
- # Store text before code block if returning remaining text
264
- if return_remaining_text:
265
- text_parts.append(text[last_end:code_delimiter_position].strip())
266
-
267
- # Check the preceding line for file name
268
- preceding_text = text[:code_delimiter_position].strip().splitlines()
269
- if preceding_text:
270
- last_line = preceding_text[-1].strip()
271
- if last_line.startswith("<file_name>") and last_line.endswith(
272
- "</file_name>"
273
- ):
274
- file_name = last_line[
275
- len("<file_name>") : -len("</file_name>")
276
- ].strip()
277
- block_infos["file_name"] = file_name
278
- elif last_line.startswith("## filename:"):
279
- file_name = last_line[len("## filename:") :].strip()
280
- block_infos["file_name"] = file_name
281
- if last_line.startswith("<section>") and last_line.endswith(
282
- "</section>"
283
- ):
284
- section = last_line[
285
- len("<section>") : -len("</section>")
286
- ].strip()
287
- block_infos["section"] = section
288
-
289
- sub_text = text[code_delimiter_position + 3 :]
290
- if len(sub_text) > 0:
291
- try:
292
- find_space = sub_text.index(" ")
293
- except Exception:
294
- find_space = int(1e10)
295
- try:
296
- find_return = sub_text.index("\n")
297
- except Exception:
298
- find_return = int(1e10)
299
- next_index = min(find_return, find_space)
300
- if "{" in sub_text[:next_index]:
301
- next_index = 0
302
- start_pos = next_index
303
-
304
- if code_delimiter_position + 3 < len(text) and text[
305
- code_delimiter_position + 3
306
- ] in ["\n", " ", "\t"]:
307
- block_infos["type"] = "language-specific"
308
- else:
309
- block_infos["type"] = sub_text[:next_index]
310
-
311
- if index + 1 < len(indices):
312
- next_pos = indices[index + 1] - code_delimiter_position
313
- if (
314
- next_pos - 3 < len(sub_text)
315
- and sub_text[next_pos - 3] == "`"
316
- ):
317
- block_infos["content"] = sub_text[
318
- start_pos : next_pos - 3
319
- ].strip()
320
- block_infos["is_complete"] = True
321
- else:
322
- block_infos["content"] = sub_text[
323
- start_pos:next_pos
324
- ].strip()
325
- block_infos["is_complete"] = False
326
-
327
- if return_remaining_text:
328
- last_end = indices[index + 1] + 3
329
- else:
330
- block_infos["content"] = sub_text[start_pos:].strip()
331
- block_infos["is_complete"] = False
332
-
333
- if return_remaining_text:
334
- last_end = len(text)
335
-
336
- code_blocks.append(block_infos)
337
- is_start = False
338
- else:
339
- is_start = True
340
-
341
- if return_remaining_text:
342
- # Add any remaining text after the last code block
343
- if last_end < len(text):
344
- text_parts.append(text[last_end:].strip())
345
- # Join all non-code parts with newlines
346
- text_without_blocks = "\n".join(filter(None, text_parts))
347
- return code_blocks, text_without_blocks
348
-
349
- return code_blocks
350
-
351
- def clean(self, csv_content: str):
352
- lines = csv_content.splitlines()
353
- if lines:
354
- # Remove spaces around headers and ensure no spaces between commas
355
- header = ",".join([col.strip() for col in lines[0].split(",")])
356
- lines[0] = header # Replace the first line with the cleaned header
357
- csv_content = "\n".join(lines)
358
- return csv_content
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lightrag/base.py CHANGED
@@ -83,11 +83,11 @@ class StorageNameSpace:
83
  namespace: str
84
  global_config: dict[str, Any]
85
 
86
- async def index_done_callback(self):
87
  """Commit the storage operations after indexing"""
88
  pass
89
 
90
- async def query_done_callback(self):
91
  """Commit the storage operations after querying"""
92
  pass
93
 
 
83
  namespace: str
84
  global_config: dict[str, Any]
85
 
86
+ async def index_done_callback(self) -> None:
87
  """Commit the storage operations after indexing"""
88
  pass
89
 
90
+ async def query_done_callback(self) -> None:
91
  """Commit the storage operations after querying"""
92
  pass
93
 
lightrag/lightrag.py CHANGED
@@ -6,7 +6,7 @@ import configparser
6
  from dataclasses import asdict, dataclass, field
7
  from datetime import datetime
8
  from functools import partial
9
- from typing import Any, Callable, Optional, Type, Union, cast
10
 
11
  from .base import (
12
  BaseGraphStorage,
@@ -304,7 +304,7 @@ class LightRAG:
304
  - random_seed: Seed value for reproducibility.
305
  """
306
 
307
- embedding_func: Union[EmbeddingFunc, None] = None
308
  """Function for computing text embeddings. Must be set before use."""
309
 
310
  embedding_batch_num: int = 32
@@ -344,10 +344,8 @@ class LightRAG:
344
 
345
  # Extensions
346
  addon_params: dict[str, Any] = field(default_factory=dict)
347
- """Dictionary for additional parameters and extensions."""
348
 
349
- # extension
350
- addon_params: dict[str, Any] = field(default_factory=dict)
351
  convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
352
  convert_response_to_json
353
  )
@@ -445,77 +443,74 @@ class LightRAG:
445
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
446
 
447
  # Init LLM
448
- self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore
449
  self.embedding_func
450
  )
451
 
452
  # Initialize all storages
453
- self.key_string_value_json_storage_cls: Type[BaseKVStorage] = ( # type: ignore
454
  self._get_storage_class(self.kv_storage)
455
- )
456
- self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class( # type: ignore
457
  self.vector_storage
458
- )
459
- self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class( # type: ignore
460
  self.graph_storage
461
- )
462
-
463
- self.key_string_value_json_storage_cls = partial( # type: ignore
464
  self.key_string_value_json_storage_cls, global_config=global_config
465
  )
466
-
467
- self.vector_db_storage_cls = partial( # type: ignore
468
  self.vector_db_storage_cls, global_config=global_config
469
  )
470
-
471
- self.graph_storage_cls = partial( # type: ignore
472
  self.graph_storage_cls, global_config=global_config
473
  )
474
 
475
  # Initialize document status storage
476
  self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
477
 
478
- self.llm_response_cache = self.key_string_value_json_storage_cls( # type: ignore
479
  namespace=make_namespace(
480
  self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
481
  ),
482
  embedding_func=self.embedding_func,
483
  )
484
 
485
- self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
486
  namespace=make_namespace(
487
  self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
488
  ),
489
  embedding_func=self.embedding_func,
490
  )
491
- self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
492
  namespace=make_namespace(
493
  self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
494
  ),
495
  embedding_func=self.embedding_func,
496
  )
497
- self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore
498
  namespace=make_namespace(
499
  self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
500
  ),
501
  embedding_func=self.embedding_func,
502
  )
503
 
504
- self.entities_vdb = self.vector_db_storage_cls( # type: ignore
505
  namespace=make_namespace(
506
  self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
507
  ),
508
  embedding_func=self.embedding_func,
509
  meta_fields={"entity_name"},
510
  )
511
- self.relationships_vdb = self.vector_db_storage_cls( # type: ignore
512
  namespace=make_namespace(
513
  self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
514
  ),
515
  embedding_func=self.embedding_func,
516
  meta_fields={"src_id", "tgt_id"},
517
  )
518
- self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
519
  namespace=make_namespace(
520
  self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
521
  ),
@@ -535,16 +530,16 @@ class LightRAG:
535
  ):
536
  hashing_kv = self.llm_response_cache
537
  else:
538
- hashing_kv = self.key_string_value_json_storage_cls( # type: ignore
539
  namespace=make_namespace(
540
  self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
541
  ),
542
  embedding_func=self.embedding_func,
543
  )
544
-
545
  self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
546
  partial(
547
- self.llm_model_func, # type: ignore
548
  hashing_kv=hashing_kv,
549
  **self.llm_model_kwargs,
550
  )
@@ -836,32 +831,32 @@ class LightRAG:
836
  raise e
837
 
838
  async def _insert_done(self):
839
- tasks = []
840
- for storage_inst in [
841
- self.full_docs,
842
- self.text_chunks,
843
- self.llm_response_cache,
844
- self.entities_vdb,
845
- self.relationships_vdb,
846
- self.chunks_vdb,
847
- self.chunk_entity_relation_graph,
848
- ]:
849
- if storage_inst is None:
850
- continue
851
- tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
852
  await asyncio.gather(*tasks)
853
 
854
- def insert_custom_kg(self, custom_kg: dict[str, dict[str, str]]):
855
  loop = always_get_an_event_loop()
856
  return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))
857
 
858
- async def ainsert_custom_kg(self, custom_kg: dict[str, dict[str, str]]):
859
  update_storage = False
860
  try:
861
  # Insert chunks into vector storage
862
- all_chunks_data = {}
863
- chunk_to_source_map = {}
864
- for chunk_data in custom_kg.get("chunks", []):
865
  chunk_content = chunk_data["content"]
866
  source_id = chunk_data["source_id"]
867
  chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")
@@ -871,13 +866,13 @@ class LightRAG:
871
  chunk_to_source_map[source_id] = chunk_id
872
  update_storage = True
873
 
874
- if self.chunks_vdb is not None and all_chunks_data:
875
  await self.chunks_vdb.upsert(all_chunks_data)
876
- if self.text_chunks is not None and all_chunks_data:
877
  await self.text_chunks.upsert(all_chunks_data)
878
 
879
  # Insert entities into knowledge graph
880
- all_entities_data = []
881
  for entity_data in custom_kg.get("entities", []):
882
  entity_name = f'"{entity_data["entity_name"].upper()}"'
883
  entity_type = entity_data.get("entity_type", "UNKNOWN")
@@ -893,7 +888,7 @@ class LightRAG:
893
  )
894
 
895
  # Prepare node data
896
- node_data = {
897
  "entity_type": entity_type,
898
  "description": description,
899
  "source_id": source_id,
@@ -907,7 +902,7 @@ class LightRAG:
907
  update_storage = True
908
 
909
  # Insert relationships into knowledge graph
910
- all_relationships_data = []
911
  for relationship_data in custom_kg.get("relationships", []):
912
  src_id = f'"{relationship_data["src_id"].upper()}"'
913
  tgt_id = f'"{relationship_data["tgt_id"].upper()}"'
@@ -949,7 +944,7 @@ class LightRAG:
949
  "source_id": source_id,
950
  },
951
  )
952
- edge_data = {
953
  "src_id": src_id,
954
  "tgt_id": tgt_id,
955
  "description": description,
@@ -959,19 +954,17 @@ class LightRAG:
959
  update_storage = True
960
 
961
  # Insert entities into vector storage if needed
962
- if self.entities_vdb is not None:
963
- data_for_vdb = {
964
  compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
965
  "content": dp["entity_name"] + dp["description"],
966
  "entity_name": dp["entity_name"],
967
  }
968
  for dp in all_entities_data
969
  }
970
- await self.entities_vdb.upsert(data_for_vdb)
971
 
972
  # Insert relationships into vector storage if needed
973
- if self.relationships_vdb is not None:
974
- data_for_vdb = {
975
  compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
976
  "src_id": dp["src_id"],
977
  "tgt_id": dp["tgt_id"],
@@ -982,18 +975,49 @@ class LightRAG:
982
  }
983
  for dp in all_relationships_data
984
  }
985
- await self.relationships_vdb.upsert(data_for_vdb)
 
986
  finally:
987
  if update_storage:
988
  await self._insert_done()
989
 
990
- def query(self, query: str, prompt: str = "", param: QueryParam = QueryParam()):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
991
  loop = always_get_an_event_loop()
992
- return loop.run_until_complete(self.aquery(query, prompt, param))
993
 
994
  async def aquery(
995
- self, query: str, prompt: str = "", param: QueryParam = QueryParam()
996
- ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
997
  if param.mode in ["local", "global", "hybrid"]:
998
  response = await kg_query(
999
  query,
 
6
  from dataclasses import asdict, dataclass, field
7
  from datetime import datetime
8
  from functools import partial
9
+ from typing import Any, Callable, Optional, Union, cast
10
 
11
  from .base import (
12
  BaseGraphStorage,
 
304
  - random_seed: Seed value for reproducibility.
305
  """
306
 
307
+ embedding_func: EmbeddingFunc | None = None
308
  """Function for computing text embeddings. Must be set before use."""
309
 
310
  embedding_batch_num: int = 32
 
344
 
345
  # Extensions
346
  addon_params: dict[str, Any] = field(default_factory=dict)
 
347
 
348
+ """Dictionary for additional parameters and extensions."""
 
349
  convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
350
  convert_response_to_json
351
  )
 
443
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
444
 
445
  # Init LLM
446
+ self.embedding_func = limit_async_func_call(self.embedding_func_max_async)( # type: ignore
447
  self.embedding_func
448
  )
449
 
450
  # Initialize all storages
451
+ self.key_string_value_json_storage_cls: type[BaseKVStorage] = (
452
  self._get_storage_class(self.kv_storage)
453
+ ) # type: ignore
454
+ self.vector_db_storage_cls: type[BaseVectorStorage] = self._get_storage_class(
455
  self.vector_storage
456
+ ) # type: ignore
457
+ self.graph_storage_cls: type[BaseGraphStorage] = self._get_storage_class(
458
  self.graph_storage
459
+ ) # type: ignore
460
+ self.key_string_value_json_storage_cls = partial( # type: ignore
 
461
  self.key_string_value_json_storage_cls, global_config=global_config
462
  )
463
+ self.vector_db_storage_cls = partial( # type: ignore
 
464
  self.vector_db_storage_cls, global_config=global_config
465
  )
466
+ self.graph_storage_cls = partial( # type: ignore
 
467
  self.graph_storage_cls, global_config=global_config
468
  )
469
 
470
  # Initialize document status storage
471
  self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
472
 
473
+ self.llm_response_cache: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
474
  namespace=make_namespace(
475
  self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
476
  ),
477
  embedding_func=self.embedding_func,
478
  )
479
 
480
+ self.full_docs: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
481
  namespace=make_namespace(
482
  self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS
483
  ),
484
  embedding_func=self.embedding_func,
485
  )
486
+ self.text_chunks: BaseKVStorage = self.key_string_value_json_storage_cls( # type: ignore
487
  namespace=make_namespace(
488
  self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS
489
  ),
490
  embedding_func=self.embedding_func,
491
  )
492
+ self.chunk_entity_relation_graph: BaseGraphStorage = self.graph_storage_cls( # type: ignore
493
  namespace=make_namespace(
494
  self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION
495
  ),
496
  embedding_func=self.embedding_func,
497
  )
498
 
499
+ self.entities_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
500
  namespace=make_namespace(
501
  self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES
502
  ),
503
  embedding_func=self.embedding_func,
504
  meta_fields={"entity_name"},
505
  )
506
+ self.relationships_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
507
  namespace=make_namespace(
508
  self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS
509
  ),
510
  embedding_func=self.embedding_func,
511
  meta_fields={"src_id", "tgt_id"},
512
  )
513
+ self.chunks_vdb: BaseVectorStorage = self.vector_db_storage_cls( # type: ignore
514
  namespace=make_namespace(
515
  self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS
516
  ),
 
530
  ):
531
  hashing_kv = self.llm_response_cache
532
  else:
533
+ hashing_kv = self.key_string_value_json_storage_cls( # type: ignore
534
  namespace=make_namespace(
535
  self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
536
  ),
537
  embedding_func=self.embedding_func,
538
  )
539
+
540
  self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
541
  partial(
542
+ self.llm_model_func, # type: ignore
543
  hashing_kv=hashing_kv,
544
  **self.llm_model_kwargs,
545
  )
 
831
  raise e
832
 
833
  async def _insert_done(self):
834
+ tasks = [
835
+ cast(StorageNameSpace, storage_inst).index_done_callback()
836
+ for storage_inst in [ # type: ignore
837
+ self.full_docs,
838
+ self.text_chunks,
839
+ self.llm_response_cache,
840
+ self.entities_vdb,
841
+ self.relationships_vdb,
842
+ self.chunks_vdb,
843
+ self.chunk_entity_relation_graph,
844
+ ]
845
+ if storage_inst is not None
846
+ ]
847
  await asyncio.gather(*tasks)
848
 
849
+ def insert_custom_kg(self, custom_kg: dict[str, Any]):
850
  loop = always_get_an_event_loop()
851
  return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))
852
 
853
+ async def ainsert_custom_kg(self, custom_kg: dict[str, Any]):
854
  update_storage = False
855
  try:
856
  # Insert chunks into vector storage
857
+ all_chunks_data: dict[str, dict[str, str]] = {}
858
+ chunk_to_source_map: dict[str, str] = {}
859
+ for chunk_data in custom_kg.get("chunks", {}):
860
  chunk_content = chunk_data["content"]
861
  source_id = chunk_data["source_id"]
862
  chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")
 
866
  chunk_to_source_map[source_id] = chunk_id
867
  update_storage = True
868
 
869
+ if all_chunks_data:
870
  await self.chunks_vdb.upsert(all_chunks_data)
871
+ if all_chunks_data:
872
  await self.text_chunks.upsert(all_chunks_data)
873
 
874
  # Insert entities into knowledge graph
875
+ all_entities_data: list[dict[str, str]] = []
876
  for entity_data in custom_kg.get("entities", []):
877
  entity_name = f'"{entity_data["entity_name"].upper()}"'
878
  entity_type = entity_data.get("entity_type", "UNKNOWN")
 
888
  )
889
 
890
  # Prepare node data
891
+ node_data: dict[str, str] = {
892
  "entity_type": entity_type,
893
  "description": description,
894
  "source_id": source_id,
 
902
  update_storage = True
903
 
904
  # Insert relationships into knowledge graph
905
+ all_relationships_data: list[dict[str, str]] = []
906
  for relationship_data in custom_kg.get("relationships", []):
907
  src_id = f'"{relationship_data["src_id"].upper()}"'
908
  tgt_id = f'"{relationship_data["tgt_id"].upper()}"'
 
944
  "source_id": source_id,
945
  },
946
  )
947
+ edge_data: dict[str, str] = {
948
  "src_id": src_id,
949
  "tgt_id": tgt_id,
950
  "description": description,
 
954
  update_storage = True
955
 
956
  # Insert entities into vector storage if needed
957
+ data_for_vdb = {
 
958
  compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
959
  "content": dp["entity_name"] + dp["description"],
960
  "entity_name": dp["entity_name"],
961
  }
962
  for dp in all_entities_data
963
  }
964
+ await self.entities_vdb.upsert(data_for_vdb)
965
 
966
  # Insert relationships into vector storage if needed
967
+ data_for_vdb = {
 
968
  compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
969
  "src_id": dp["src_id"],
970
  "tgt_id": dp["tgt_id"],
 
975
  }
976
  for dp in all_relationships_data
977
  }
978
+ await self.relationships_vdb.upsert(data_for_vdb)
979
+
980
  finally:
981
  if update_storage:
982
  await self._insert_done()
983
 
984
+ def query(
985
+ self,
986
+ query: str,
987
+ param: QueryParam = QueryParam(),
988
+ prompt: str | None = None
989
+ ) -> str:
990
+ """
991
+ Perform a sync query.
992
+
993
+ Args:
994
+ query (str): The query to be executed.
995
+ param (QueryParam): Configuration parameters for query execution.
996
+ prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
997
+
998
+ Returns:
999
+ str: The result of the query execution.
1000
+ """
1001
  loop = always_get_an_event_loop()
1002
+ return loop.run_until_complete(self.aquery(query, param, prompt))
1003
 
1004
  async def aquery(
1005
+ self,
1006
+ query: str,
1007
+ param: QueryParam = QueryParam(),
1008
+ prompt: str | None = None,
1009
+ ) -> str:
1010
+ """
1011
 + Perform an async query.
1012
+
1013
+ Args:
1014
+ query (str): The query to be executed.
1015
+ param (QueryParam): Configuration parameters for query execution.
1016
+ prompt (Optional[str]): Custom prompts for fine-tuned control over the system's behavior. Defaults to None, which uses PROMPTS["rag_response"].
1017
+
1018
+ Returns:
1019
+ str: The result of the query execution.
1020
+ """
1021
  if param.mode in ["local", "global", "hybrid"]:
1022
  response = await kg_query(
1023
  query,
lightrag/operate.py CHANGED
@@ -295,8 +295,8 @@ async def extract_entities(
295
  knowledge_graph_inst: BaseGraphStorage,
296
  entity_vdb: BaseVectorStorage,
297
  relationships_vdb: BaseVectorStorage,
298
- global_config: dict,
299
- llm_response_cache: BaseKVStorage = None,
300
  ) -> Union[BaseGraphStorage, None]:
301
  use_llm_func: callable = global_config["llm_model_func"]
302
  entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
@@ -563,15 +563,15 @@ async def extract_entities(
563
 
564
 
565
  async def kg_query(
566
- query,
567
  knowledge_graph_inst: BaseGraphStorage,
568
  entities_vdb: BaseVectorStorage,
569
  relationships_vdb: BaseVectorStorage,
570
  text_chunks_db: BaseKVStorage,
571
  query_param: QueryParam,
572
- global_config: dict,
573
- hashing_kv: BaseKVStorage = None,
574
- prompt: str = "",
575
  ) -> str:
576
  # Handle cache
577
  use_model_func = global_config["llm_model_func"]
@@ -681,8 +681,8 @@ async def kg_query(
681
  async def extract_keywords_only(
682
  text: str,
683
  param: QueryParam,
684
- global_config: dict,
685
- hashing_kv: BaseKVStorage = None,
686
  ) -> tuple[list[str], list[str]]:
687
  """
688
  Extract high-level and low-level keywords from the given 'text' using the LLM.
@@ -778,8 +778,8 @@ async def mix_kg_vector_query(
778
  chunks_vdb: BaseVectorStorage,
779
  text_chunks_db: BaseKVStorage,
780
  query_param: QueryParam,
781
- global_config: dict,
782
- hashing_kv: BaseKVStorage = None,
783
  ) -> str:
784
  """
785
  Hybrid retrieval implementation combining knowledge graph and vector search.
@@ -1499,12 +1499,12 @@ def combine_contexts(entities, relationships, sources):
1499
 
1500
 
1501
  async def naive_query(
1502
- query,
1503
  chunks_vdb: BaseVectorStorage,
1504
  text_chunks_db: BaseKVStorage,
1505
  query_param: QueryParam,
1506
- global_config: dict,
1507
- hashing_kv: BaseKVStorage = None,
1508
  ):
1509
  # Handle cache
1510
  use_model_func = global_config["llm_model_func"]
@@ -1606,8 +1606,8 @@ async def kg_query_with_keywords(
1606
  relationships_vdb: BaseVectorStorage,
1607
  text_chunks_db: BaseKVStorage,
1608
  query_param: QueryParam,
1609
- global_config: dict,
1610
- hashing_kv: BaseKVStorage = None,
1611
  ) -> str:
1612
  """
1613
  Refactored kg_query that does NOT extract keywords by itself.
 
295
  knowledge_graph_inst: BaseGraphStorage,
296
  entity_vdb: BaseVectorStorage,
297
  relationships_vdb: BaseVectorStorage,
298
+ global_config: dict[str, str],
299
+ llm_response_cache: BaseKVStorage | None = None,
300
  ) -> Union[BaseGraphStorage, None]:
301
  use_llm_func: callable = global_config["llm_model_func"]
302
  entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
 
563
 
564
 
565
  async def kg_query(
566
+ query: str,
567
  knowledge_graph_inst: BaseGraphStorage,
568
  entities_vdb: BaseVectorStorage,
569
  relationships_vdb: BaseVectorStorage,
570
  text_chunks_db: BaseKVStorage,
571
  query_param: QueryParam,
572
+ global_config: dict[str, str],
573
+ hashing_kv: BaseKVStorage | None = None,
574
+ prompt: str | None = None,
575
  ) -> str:
576
  # Handle cache
577
  use_model_func = global_config["llm_model_func"]
 
681
  async def extract_keywords_only(
682
  text: str,
683
  param: QueryParam,
684
+ global_config: dict[str, str],
685
+ hashing_kv: BaseKVStorage | None = None,
686
  ) -> tuple[list[str], list[str]]:
687
  """
688
  Extract high-level and low-level keywords from the given 'text' using the LLM.
 
778
  chunks_vdb: BaseVectorStorage,
779
  text_chunks_db: BaseKVStorage,
780
  query_param: QueryParam,
781
+ global_config: dict[str, str],
782
+ hashing_kv: BaseKVStorage | None = None,
783
  ) -> str:
784
  """
785
  Hybrid retrieval implementation combining knowledge graph and vector search.
 
1499
 
1500
 
1501
  async def naive_query(
1502
+ query: str,
1503
  chunks_vdb: BaseVectorStorage,
1504
  text_chunks_db: BaseKVStorage,
1505
  query_param: QueryParam,
1506
+ global_config: dict[str, str],
1507
+ hashing_kv: BaseKVStorage | None = None,
1508
  ):
1509
  # Handle cache
1510
  use_model_func = global_config["llm_model_func"]
 
1606
  relationships_vdb: BaseVectorStorage,
1607
  text_chunks_db: BaseKVStorage,
1608
  query_param: QueryParam,
1609
+ global_config: dict[str, str],
1610
+ hashing_kv: BaseKVStorage | None = None,
1611
  ) -> str:
1612
  """
1613
  Refactored kg_query that does NOT extract keywords by itself.
lightrag/utils.py CHANGED
@@ -128,7 +128,12 @@ def compute_args_hash(*args, cache_type: str = None) -> str:
128
  return hashlib.md5(args_str.encode()).hexdigest()
129
 
130
 
131
- def compute_mdhash_id(content, prefix: str = ""):
 
 
 
 
 
132
  return prefix + md5(content.encode()).hexdigest()
133
 
134
 
 
128
  return hashlib.md5(args_str.encode()).hexdigest()
129
 
130
 
131
+ def compute_mdhash_id(content: str, prefix: str = "") -> str:
132
+ """
133
+ Compute a unique ID for a given content string.
134
+
135
+ The ID is a combination of the given prefix and the MD5 hash of the content string.
136
+ """
137
  return prefix + md5(content.encode()).hexdigest()
138
 
139