alazarchuk committed
Commit 0feb46d · 2 Parent(s): 3aa449a 4c2ac8a

Merge remote-tracking branch 'origin/main' into fix-ollama-integration

.gitignore CHANGED
@@ -1,121 +1,8 @@
- # Byte-compiled / optimized / DLL files
- __pycache__/
- *.py[cod]
- *$py.class
-
- # C extensions
- *.so
-
- # Distribution / packaging
- .Python
- env/
- venv/
- ENV/
- env.bak/
- venv.bak/
- *.egg
- *.egg-info/
- dist/
- build/
- *.whl
-
- # PyInstaller
- # Usually these files are written by a python script from a template
- # before PyInstaller builds the exe, so as to inject date/other infos into it.
- *.manifest
- *.spec
-
- # Installer logs
- pip-log.txt
- pip-delete-this-directory.txt
-
- # Unit test / coverage reports
- htmlcov/
- .tox/
- .nox/
- .coverage
- .coverage.*
- .cache
- nosetests.xml
- coverage.xml
- *.cover
- *.py,cover
- .hypothesis/
-
- # Translations
- *.mo
- *.pot
-
- # Django stuff:
- *.log
- local_settings.py
- db.sqlite3
- db.sqlite3-journal
-
- # Flask stuff:
- instance/
- .webassets-cache
-
- # Scrapy stuff:
- .scrapy
-
- # Sphinx documentation
- docs/_build/
-
- # PyBuilder
- target/
-
- # Jupyter Notebook
- .ipynb_checkpoints
-
- # IPython
- profile_default/
- ipython_config.py
-
- # pyenv
- .python-version
-
- # celery beat schedule file
- celerybeat-schedule
-
- # SageMath parsed files
- *.sage.py
-
- # Environments
- .env
- .env.*
- .venv
- .venv.*
- env/
- venv/
- ENV/
- env.bak/
- venv.bak/
-
- # Spyder project settings
- .spyderproject
- .spyderworkspace
-
- # Rope project settings
- .ropeproject
-
- # mkdocs documentation
- /site
-
- # mypy
- .mypy_cache/
- .dmypy.json
- dmypy.json
-
- # Pyre type checker
- .pyre/
-
- # pytype static type analyzer
- .pytype/
-
- # Cython debug symbols
- cython_debug/
-
- # Example files
- book.txt
+ __pycache__
+ *.egg-info
  dickens/
+ book.txt
+ lightrag-dev/
+ .idea/
+ dist/
+ .venv/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,22 @@
+ repos:
+   - repo: https://github.com/pre-commit/pre-commit-hooks
+     rev: v5.0.0
+     hooks:
+       - id: trailing-whitespace
+       - id: end-of-file-fixer
+       - id: requirements-txt-fixer
+
+
+   - repo: https://github.com/astral-sh/ruff-pre-commit
+     rev: v0.6.4
+     hooks:
+       - id: ruff-format
+       - id: ruff
+         args: [--fix]
+
+
+   - repo: https://github.com/mgedmin/check-manifest
+     rev: "0.49"
+     hooks:
+       - id: check-manifest
+         stages: [manual]
README.md CHANGED
@@ -16,16 +16,17 @@
   <a href="https://pypi.org/project/lightrag-hku/"><img src="https://img.shields.io/pypi/v/lightrag-hku.svg"></a>
   <a href="https://pepy.tech/project/lightrag-hku"><img src="https://static.pepy.tech/badge/lightrag-hku/month"></a>
   </p>
-
   This repository hosts the code of LightRAG. The structure of this code is based on [nano-graphrag](https://github.com/gusye1234/nano-graphrag).
   ![请添加图片描述](https://i-blog.csdnimg.cn/direct/b2aaf634151b4706892693ffb43d9093.png)
   </div>

- ## 🎉 News
   - [x] [2024.10.18]🎯🎯📢📢We’ve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
   - [x] [2024.10.17]🎯🎯📢📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉
- - [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
- - [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!

   ## Install

@@ -47,12 +48,21 @@ pip install lightrag-hku
   ```bash
   curl https://raw.githubusercontent.com/gusye1234/nano-graphrag/main/tests/mock_data.txt > ./book.txt
   ```
- Use the below Python snippet to initialize LightRAG and perform queries:

   ```python
   from lightrag import LightRAG, QueryParam
   from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete

   WORKING_DIR = "./dickens"

   if not os.path.exists(WORKING_DIR):
@@ -83,7 +93,7 @@ print(rag.query("What are the top themes in this story?", param=QueryParam(mode=
   <details>
   <summary> Using Open AI-like APIs </summary>

- LightRAG also support Open AI-like chat/embeddings APIs:
   ```python
   async def llm_model_func(
       prompt, system_prompt=None, history_messages=[], **kwargs
@@ -120,8 +130,8 @@ rag = LightRAG(
   <details>
   <summary> Using Hugging Face Models </summary>
-
- If you want to use Hugging Face models, you only need to set LightRAG as follows:
   ```python
   from lightrag.llm import hf_model_complete, hf_embedding
   from transformers import AutoModel, AutoTokenizer
@@ -136,7 +146,7 @@ rag = LightRAG(
       embedding_dim=384,
       max_token_size=5000,
       func=lambda texts: hf_embedding(
-          texts,
           tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
           embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
       )
@@ -147,8 +157,9 @@ rag = LightRAG(
   <details>
   <summary> Using Ollama Models </summary>
- If you want to use Ollama models, you only need to set LightRAG as follows:

   ```python
   from lightrag.llm import ollama_model_complete, ollama_embedding

@@ -162,12 +173,35 @@ rag = LightRAG(
       embedding_dim=768,
       max_token_size=8192,
       func=lambda texts: ollama_embedding(
-          texts,
           embed_model="nomic-embed-text"
       )
   ),
   )
   ```
   </details>

   ### Batch Insert
@@ -185,16 +219,171 @@ rag = LightRAG(working_dir="./dickens")
   with open("./newText.txt") as f:
       rag.insert(f.read())
   ```
   ## Evaluation
   ### Dataset
- The dataset used in LightRAG can be download from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).

   ### Generate Query
- LightRAG uses the following prompt to generate high-level queries, with the corresponding code located in `example/generate_query.py`.

   <details>
   <summary> Prompt </summary>
-
   ```python
   Given the following description of a dataset:

@@ -219,18 +408,18 @@ Output the results in the following structure:
   ...
   ```
   </details>
-
   ### Batch Eval
   To evaluate the performance of two RAG systems on high-level queries, LightRAG uses the following prompt, with the specific code available in `example/batch_eval.py`.

   <details>
   <summary> Prompt </summary>
-
   ```python
   ---Role---
   You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
   ---Goal---
- You will evaluate two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.

   - **Comprehensiveness**: How much detail does the answer provide to cover all aspects and details of the question?
   - **Diversity**: How varied and rich is the answer in providing different perspectives and insights on the question?
@@ -294,7 +483,7 @@ Output your evaluation in the following JSON format:
   | **Empowerment** | 36.69% | **63.31%** | 45.09% | **54.91%** | 42.81% | **57.19%** | **52.94%** | 47.06% |
   | **Overall** | 43.62% | **56.38%** | 45.98% | **54.02%** | 45.70% | **54.30%** | **51.86%** | 48.14% |

- ## Reproduce
   All the code can be found in the `./reproduce` directory.

   ### Step-0 Extract Unique Contexts
@@ -302,7 +491,7 @@ First, we need to extract unique contexts in the datasets.
   <details>
   <summary> Code </summary>
-
   ```python
   def extract_unique_contexts(input_directory, output_directory):

@@ -361,12 +550,12 @@ For the extracted contexts, we insert them into the LightRAG system.
   <details>
   <summary> Code </summary>
-
   ```python
   def insert_text(rag, file_path):
       with open(file_path, mode='r') as f:
           unique_contexts = json.load(f)
-
       retries = 0
       max_retries = 3
       while retries < max_retries:
@@ -384,11 +573,11 @@ def insert_text(rag, file_path):
   ### Step-2 Generate Queries

- We extract tokens from both the first half and the second half of each context in the dataset, then combine them as the dataset description to generate queries.

   <details>
   <summary> Code </summary>
-
   ```python
   tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

@@ -401,7 +590,7 @@ def get_summary(context, tot_tokens=2000):
       summary_tokens = start_tokens + end_tokens
       summary = tokenizer.convert_tokens_to_string(summary_tokens)
-
       return summary
   ```
   </details>
@@ -411,12 +600,12 @@ For the queries generated in Step-2, we will extract them and query LightRAG.
   <details>
   <summary> Code </summary>
-
   ```python
   def extract_queries(file_path):
       with open(file_path, 'r') as f:
           data = f.read()
-
       data = data.replace('**', '')

       queries = re.findall(r'- Question \d+: (.+)', data)
@@ -431,11 +620,16 @@ def extract_queries(file_path):
   .
   ├── examples
   │   ├── batch_eval.py
   │   ├── generate_query.py
   │   ├── lightrag_hf_demo.py
   │   ├── lightrag_ollama_demo.py
   │   ├── lightrag_openai_compatible_demo.py
- └── lightrag_openai_demo.py
   ├── lightrag
   │   ├── __init__.py
   │   ├── base.py
@@ -450,6 +644,8 @@ def extract_queries(file_path):
   │   ├── Step_1.py
   │   ├── Step_2.py
   │   └── Step_3.py
   ├── LICENSE
   ├── README.md
   ├── requirements.txt
@@ -470,7 +666,7 @@ def extract_queries(file_path):
   ```python
   @article{guo2024lightrag,
- title={LightRAG: Simple and Fast Retrieval-Augmented Generation},
   author={Zirui Guo and Lianghao Xia and Yanhua Yu and Tu Ao and Chao Huang},
   year={2024},
   eprint={2410.05779},
 
   <a href="https://pypi.org/project/lightrag-hku/"><img src="https://img.shields.io/pypi/v/lightrag-hku.svg"></a>
   <a href="https://pepy.tech/project/lightrag-hku"><img src="https://static.pepy.tech/badge/lightrag-hku/month"></a>
   </p>
+
   This repository hosts the code of LightRAG. The structure of this code is based on [nano-graphrag](https://github.com/gusye1234/nano-graphrag).
   ![请添加图片描述](https://i-blog.csdnimg.cn/direct/b2aaf634151b4706892693ffb43d9093.png)
   </div>

+ ## 🎉 News
+ - [x] [2024.10.20]🎯🎯📢📢We’ve added a new feature to LightRAG: Graph Visualization.
   - [x] [2024.10.18]🎯🎯📢📢We’ve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
   - [x] [2024.10.17]🎯🎯📢📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉
+ - [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
+ - [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!

   ## Install

   ```bash
   curl https://raw.githubusercontent.com/gusye1234/nano-graphrag/main/tests/mock_data.txt > ./book.txt
   ```
+ Use the below Python snippet (in a script) to initialize LightRAG and perform queries:

   ```python
   from lightrag import LightRAG, QueryParam
   from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete

+ #########
+ # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
+ # import nest_asyncio
+ # nest_asyncio.apply()
+ #########
+
+ WORKING_DIR = "./dickens"
+
+
   WORKING_DIR = "./dickens"

   if not os.path.exists(WORKING_DIR):

   <details>
   <summary> Using Open AI-like APIs </summary>

+ * LightRAG also supports Open AI-like chat/embeddings APIs:
   ```python
   async def llm_model_func(
       prompt, system_prompt=None, history_messages=[], **kwargs

   <details>
   <summary> Using Hugging Face Models </summary>
+
+ * If you want to use Hugging Face models, you only need to set LightRAG as follows:
   ```python
   from lightrag.llm import hf_model_complete, hf_embedding
   from transformers import AutoModel, AutoTokenizer

       embedding_dim=384,
       max_token_size=5000,
       func=lambda texts: hf_embedding(
+          texts,
           tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
           embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
       )

   <details>
   <summary> Using Ollama Models </summary>

+ * If you want to use Ollama models, you only need to set LightRAG as follows:
+
   ```python
   from lightrag.llm import ollama_model_complete, ollama_embedding
       embedding_dim=768,
       max_token_size=8192,
       func=lambda texts: ollama_embedding(
+          texts,
           embed_model="nomic-embed-text"
       )
   ),
   )
   ```
+
+ * Increasing the `num_ctx` parameter:
+
+ 1. Pull the model:
+ ```bash
+ ollama pull qwen2
+ ```
+
+ 2. Display the model file:
+ ```bash
+ ollama show --modelfile qwen2 > Modelfile
+ ```
+
+ 3. Edit the Modelfile by adding the following line:
+ ```
+ PARAMETER num_ctx 32768
+ ```
+
+ 4. Create the modified model:
+ ```bash
+ ollama create -f Modelfile qwen2m
+ ```
+
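For reference, a minimal sketch of pointing LightRAG at the modified model. It assumes the `qwen2m` tag created in step 4 and simply reuses the Ollama configuration shown above; it is not part of this commit:

```python
from lightrag import LightRAG
from lightrag.llm import ollama_model_complete, ollama_embedding
from lightrag.utils import EmbeddingFunc

# Sketch only: "qwen2m" is the tag created above with `ollama create -f Modelfile qwen2m`.
rag = LightRAG(
    working_dir="./dickens",
    llm_model_func=ollama_model_complete,
    llm_model_name="qwen2m",  # model rebuilt with PARAMETER num_ctx 32768
    embedding_func=EmbeddingFunc(
        embedding_dim=768,
        max_token_size=8192,
        func=lambda texts: ollama_embedding(texts, embed_model="nomic-embed-text"),
    ),
)
```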
  </details>

   ### Batch Insert

   with open("./newText.txt") as f:
       rag.insert(f.read())
   ```
+
+ ### Graph Visualization
+
+ <details>
+ <summary> Graph visualization with html </summary>
+
+ * The following code can be found in `examples/graph_visual_with_html.py`
+
+ ```python
+ import networkx as nx
+ from pyvis.network import Network
+
+ # Load the GraphML file
+ G = nx.read_graphml('./dickens/graph_chunk_entity_relation.graphml')
+
+ # Create a Pyvis network
+ net = Network(notebook=True)
+
+ # Convert NetworkX graph to Pyvis network
+ net.from_nx(G)
+
+ # Save and display the network
+ net.show('knowledge_graph.html')
+ ```
+
+ </details>
+
+ <details>
+ <summary> Graph visualization with Neo4j </summary>
+
+ * The following code can be found in `examples/graph_visual_with_neo4j.py`
+
+ ```python
+ import os
+ import json
+ from lightrag.utils import xml_to_json
+ from neo4j import GraphDatabase
+
+ # Constants
+ WORKING_DIR = "./dickens"
+ BATCH_SIZE_NODES = 500
+ BATCH_SIZE_EDGES = 100
+
+ # Neo4j connection credentials
+ NEO4J_URI = "bolt://localhost:7687"
+ NEO4J_USERNAME = "neo4j"
+ NEO4J_PASSWORD = "your_password"
+
+ def convert_xml_to_json(xml_path, output_path):
+     """Converts XML file to JSON and saves the output."""
+     if not os.path.exists(xml_path):
+         print(f"Error: File not found - {xml_path}")
+         return None
+
+     json_data = xml_to_json(xml_path)
+     if json_data:
+         with open(output_path, 'w', encoding='utf-8') as f:
+             json.dump(json_data, f, ensure_ascii=False, indent=2)
+         print(f"JSON file created: {output_path}")
+         return json_data
+     else:
+         print("Failed to create JSON data")
+         return None
+
+ def process_in_batches(tx, query, data, batch_size):
+     """Process data in batches and execute the given query."""
+     for i in range(0, len(data), batch_size):
+         batch = data[i:i + batch_size]
+         tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})
+
+ def main():
+     # Paths
+     xml_file = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml')
+     json_file = os.path.join(WORKING_DIR, 'graph_data.json')
+
+     # Convert XML to JSON
+     json_data = convert_xml_to_json(xml_file, json_file)
+     if json_data is None:
+         return
+
+     # Load nodes and edges
+     nodes = json_data.get('nodes', [])
+     edges = json_data.get('edges', [])
+
+     # Neo4j queries
+     create_nodes_query = """
+     UNWIND $nodes AS node
+     MERGE (e:Entity {id: node.id})
+     SET e.entity_type = node.entity_type,
+         e.description = node.description,
+         e.source_id = node.source_id,
+         e.displayName = node.id
+     REMOVE e:Entity
+     WITH e, node
+     CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
+     RETURN count(*)
+     """
+
+     create_edges_query = """
+     UNWIND $edges AS edge
+     MATCH (source {id: edge.source})
+     MATCH (target {id: edge.target})
+     WITH source, target, edge,
+          CASE
+              WHEN edge.keywords CONTAINS 'lead' THEN 'lead'
+              WHEN edge.keywords CONTAINS 'participate' THEN 'participate'
+              WHEN edge.keywords CONTAINS 'uses' THEN 'uses'
+              WHEN edge.keywords CONTAINS 'located' THEN 'located'
+              WHEN edge.keywords CONTAINS 'occurs' THEN 'occurs'
+              ELSE REPLACE(SPLIT(edge.keywords, ',')[0], '\"', '')
+          END AS relType
+     CALL apoc.create.relationship(source, relType, {
+         weight: edge.weight,
+         description: edge.description,
+         keywords: edge.keywords,
+         source_id: edge.source_id
+     }, target) YIELD rel
+     RETURN count(*)
+     """
+
+     set_displayname_and_labels_query = """
+     MATCH (n)
+     SET n.displayName = n.id
+     WITH n
+     CALL apoc.create.setLabels(n, [n.entity_type]) YIELD node
+     RETURN count(*)
+     """
+
+     # Create a Neo4j driver
+     driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
+
+     try:
+         # Execute queries in batches
+         with driver.session() as session:
+             # Insert nodes in batches
+             session.execute_write(process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES)
+
+             # Insert edges in batches
+             session.execute_write(process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES)
+
+             # Set displayName and labels
+             session.run(set_displayname_and_labels_query)
+
+     except Exception as e:
+         print(f"Error occurred: {e}")
+
+     finally:
+         driver.close()
+
+ if __name__ == "__main__":
+     main()
+ ```
+
+ </details>
+
   ## Evaluation
   ### Dataset
+ The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).

   ### Generate Query
+ LightRAG uses the following prompt to generate high-level queries, with the corresponding code in `example/generate_query.py`.

   <details>
   <summary> Prompt </summary>
+
   ```python
   Given the following description of a dataset:

   ...
   ```
   </details>
+
   ### Batch Eval
   To evaluate the performance of two RAG systems on high-level queries, LightRAG uses the following prompt, with the specific code available in `example/batch_eval.py`.

   <details>
   <summary> Prompt </summary>
+
   ```python
   ---Role---
   You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
   ---Goal---
+ You will evaluate two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.

   - **Comprehensiveness**: How much detail does the answer provide to cover all aspects and details of the question?
   - **Diversity**: How varied and rich is the answer in providing different perspectives and insights on the question?

   | **Empowerment** | 36.69% | **63.31%** | 45.09% | **54.91%** | 42.81% | **57.19%** | **52.94%** | 47.06% |
   | **Overall** | 43.62% | **56.38%** | 45.98% | **54.02%** | 45.70% | **54.30%** | **51.86%** | 48.14% |

+ ## Reproduce
   All the code can be found in the `./reproduce` directory.

   ### Step-0 Extract Unique Contexts

   <details>
   <summary> Code </summary>
+
   ```python
   def extract_unique_contexts(input_directory, output_directory):

   <details>
   <summary> Code </summary>
+
   ```python
   def insert_text(rag, file_path):
       with open(file_path, mode='r') as f:
           unique_contexts = json.load(f)
+
       retries = 0
       max_retries = 3
       while retries < max_retries:

   ### Step-2 Generate Queries

+ We extract tokens from the first and the second half of each context in the dataset, then combine them as dataset descriptions to generate queries.

   <details>
   <summary> Code </summary>
+
   ```python
   tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

       summary_tokens = start_tokens + end_tokens
       summary = tokenizer.convert_tokens_to_string(summary_tokens)
+
       return summary
   ```
   </details>

   <details>
   <summary> Code </summary>
+
   ```python
   def extract_queries(file_path):
       with open(file_path, 'r') as f:
           data = f.read()
+
       data = data.replace('**', '')

       queries = re.findall(r'- Question \d+: (.+)', data)

   .
   ├── examples
   │   ├── batch_eval.py
+  │   ├── graph_visual_with_html.py
+  │   ├── graph_visual_with_neo4j.py
   │   ├── generate_query.py
+  │   ├── lightrag_azure_openai_demo.py
+  │   ├── lightrag_bedrock_demo.py
   │   ├── lightrag_hf_demo.py
   │   ├── lightrag_ollama_demo.py
   │   ├── lightrag_openai_compatible_demo.py
+  ├── lightrag_openai_demo.py
+  │   └── vram_management_demo.py
   ├── lightrag
   │   ├── __init__.py
   │   ├── base.py

   │   ├── Step_1.py
   │   ├── Step_2.py
   │   └── Step_3.py
+  ├── .gitignore
+  ├── .pre-commit-config.yaml
   ├── LICENSE
   ├── README.md
   ├── requirements.txt

   ```python
   @article{guo2024lightrag,
+ title={LightRAG: Simple and Fast Retrieval-Augmented Generation},
   author={Zirui Guo and Lianghao Xia and Yanhua Yu and Tu Ao and Chao Huang},
   year={2024},
   eprint={2410.05779},
examples/batch_eval.py CHANGED
@@ -1,4 +1,3 @@
1
- import os
2
  import re
3
  import json
4
  import jsonlines
@@ -9,28 +8,28 @@ from openai import OpenAI
9
  def batch_eval(query_file, result1_file, result2_file, output_file_path):
10
  client = OpenAI()
11
 
12
- with open(query_file, 'r') as f:
13
  data = f.read()
14
 
15
- queries = re.findall(r'- Question \d+: (.+)', data)
16
 
17
- with open(result1_file, 'r') as f:
18
  answers1 = json.load(f)
19
- answers1 = [i['result'] for i in answers1]
20
 
21
- with open(result2_file, 'r') as f:
22
  answers2 = json.load(f)
23
- answers2 = [i['result'] for i in answers2]
24
 
25
  requests = []
26
  for i, (query, answer1, answer2) in enumerate(zip(queries, answers1, answers2)):
27
- sys_prompt = f"""
28
  ---Role---
29
  You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
30
  """
31
 
32
  prompt = f"""
33
- You will evaluate two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
34
 
35
  - **Comprehensiveness**: How much detail does the answer provide to cover all aspects and details of the question?
36
  - **Diversity**: How varied and rich is the answer in providing different perspectives and insights on the question?
@@ -69,7 +68,6 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
69
  }}
70
  """
71
 
72
-
73
  request_data = {
74
  "custom_id": f"request-{i+1}",
75
  "method": "POST",
@@ -78,22 +76,21 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
78
  "model": "gpt-4o-mini",
79
  "messages": [
80
  {"role": "system", "content": sys_prompt},
81
- {"role": "user", "content": prompt}
82
  ],
83
- }
84
  }
85
-
86
  requests.append(request_data)
87
 
88
- with jsonlines.open(output_file_path, mode='w') as writer:
89
  for request in requests:
90
  writer.write(request)
91
 
92
  print(f"Batch API requests written to {output_file_path}")
93
 
94
  batch_input_file = client.files.create(
95
- file=open(output_file_path, "rb"),
96
- purpose="batch"
97
  )
98
  batch_input_file_id = batch_input_file.id
99
 
@@ -101,12 +98,11 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
101
  input_file_id=batch_input_file_id,
102
  endpoint="/v1/chat/completions",
103
  completion_window="24h",
104
- metadata={
105
- "description": "nightly eval job"
106
- }
107
  )
108
 
109
- print(f'Batch {batch.id} has been created.')
 
110
 
111
  if __name__ == "__main__":
112
- batch_eval()
 
 
1
  import re
2
  import json
3
  import jsonlines
 
8
  def batch_eval(query_file, result1_file, result2_file, output_file_path):
9
  client = OpenAI()
10
 
11
+ with open(query_file, "r") as f:
12
  data = f.read()
13
 
14
+ queries = re.findall(r"- Question \d+: (.+)", data)
15
 
16
+ with open(result1_file, "r") as f:
17
  answers1 = json.load(f)
18
+ answers1 = [i["result"] for i in answers1]
19
 
20
+ with open(result2_file, "r") as f:
21
  answers2 = json.load(f)
22
+ answers2 = [i["result"] for i in answers2]
23
 
24
  requests = []
25
  for i, (query, answer1, answer2) in enumerate(zip(queries, answers1, answers2)):
26
+ sys_prompt = """
27
  ---Role---
28
  You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
29
  """
30
 
31
  prompt = f"""
32
+ You will evaluate two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
33
 
34
  - **Comprehensiveness**: How much detail does the answer provide to cover all aspects and details of the question?
35
  - **Diversity**: How varied and rich is the answer in providing different perspectives and insights on the question?
 
68
  }}
69
  """
70
 
 
71
  request_data = {
72
  "custom_id": f"request-{i+1}",
73
  "method": "POST",
 
76
  "model": "gpt-4o-mini",
77
  "messages": [
78
  {"role": "system", "content": sys_prompt},
79
+ {"role": "user", "content": prompt},
80
  ],
81
+ },
82
  }
83
+
84
  requests.append(request_data)
85
 
86
+ with jsonlines.open(output_file_path, mode="w") as writer:
87
  for request in requests:
88
  writer.write(request)
89
 
90
  print(f"Batch API requests written to {output_file_path}")
91
 
92
  batch_input_file = client.files.create(
93
+ file=open(output_file_path, "rb"), purpose="batch"
 
94
  )
95
  batch_input_file_id = batch_input_file.id
96
 
 
98
  input_file_id=batch_input_file_id,
99
  endpoint="/v1/chat/completions",
100
  completion_window="24h",
101
+ metadata={"description": "nightly eval job"},
 
 
102
  )
103
 
104
+ print(f"Batch {batch.id} has been created.")
105
+
106
 
107
  if __name__ == "__main__":
108
+ batch_eval()
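Note that `batch_eval()` takes four file paths, so the bare call under `__main__` shown in this diff would raise a `TypeError` at runtime. A hypothetical invocation for illustration only (the file names below are placeholders, not part of the commit):

```python
# Placeholder paths, for illustration only.
batch_eval(
    query_file="./queries.txt",
    result1_file="./result1.json",
    result2_file="./result2.json",
    output_file_path="./batch_requests.jsonl",
)
```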
examples/generate_query.py CHANGED
@@ -1,9 +1,8 @@
- import os
-
  from openai import OpenAI

  # os.environ["OPENAI_API_KEY"] = ""

  def openai_complete_if_cache(
      model="gpt-4o-mini", prompt=None, system_prompt=None, history_messages=[], **kwargs
  ) -> str:
@@ -47,10 +46,10 @@ if __name__ == "__main__":
      ...
      """

-     result = openai_complete_if_cache(model='gpt-4o-mini', prompt=prompt)

-     file_path = f"./queries.txt"
      with open(file_path, "w") as file:
          file.write(result)

-     print(f"Queries written to {file_path}")


  from openai import OpenAI

  # os.environ["OPENAI_API_KEY"] = ""

+
  def openai_complete_if_cache(
      model="gpt-4o-mini", prompt=None, system_prompt=None, history_messages=[], **kwargs
  ) -> str:

      ...
      """

+     result = openai_complete_if_cache(model="gpt-4o-mini", prompt=prompt)

+     file_path = "./queries.txt"
      with open(file_path, "w") as file:
          file.write(result)

+     print(f"Queries written to {file_path}")
examples/graph_visual_with_html.py ADDED
@@ -0,0 +1,19 @@
+ import networkx as nx
+ from pyvis.network import Network
+ import random
+
+ # Load the GraphML file
+ G = nx.read_graphml('./dickens/graph_chunk_entity_relation.graphml')
+
+ # Create a Pyvis network
+ net = Network(notebook=True)
+
+ # Convert NetworkX graph to Pyvis network
+ net.from_nx(G)
+
+ # Add colors to nodes
+ for node in net.nodes:
+     node['color'] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
+
+ # Save and display the network
+ net.show('knowledge_graph.html')
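If the generated HTML comes out empty, a quick sanity check of the exported graph can help. A minimal sketch, assuming the same `./dickens` working directory used above; this is editorial, not part of the commit:

```python
import networkx as nx

# Sketch only: confirm the GraphML export loads and inspect its size.
G = nx.read_graphml("./dickens/graph_chunk_entity_relation.graphml")
print(f"{G.number_of_nodes()} nodes, {G.number_of_edges()} edges")
```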
examples/graph_visual_with_neo4j.py ADDED
@@ -0,0 +1,118 @@
 
 
1
+ import os
2
+ import json
3
+ from lightrag.utils import xml_to_json
4
+ from neo4j import GraphDatabase
5
+
6
+ # Constants
7
+ WORKING_DIR = "./dickens"
8
+ BATCH_SIZE_NODES = 500
9
+ BATCH_SIZE_EDGES = 100
10
+
11
+ # Neo4j connection credentials
12
+ NEO4J_URI = "bolt://localhost:7687"
13
+ NEO4J_USERNAME = "neo4j"
14
+ NEO4J_PASSWORD = "your_password"
15
+
16
+ def convert_xml_to_json(xml_path, output_path):
17
+ """Converts XML file to JSON and saves the output."""
18
+ if not os.path.exists(xml_path):
19
+ print(f"Error: File not found - {xml_path}")
20
+ return None
21
+
22
+ json_data = xml_to_json(xml_path)
23
+ if json_data:
24
+ with open(output_path, 'w', encoding='utf-8') as f:
25
+ json.dump(json_data, f, ensure_ascii=False, indent=2)
26
+ print(f"JSON file created: {output_path}")
27
+ return json_data
28
+ else:
29
+ print("Failed to create JSON data")
30
+ return None
31
+
32
+ def process_in_batches(tx, query, data, batch_size):
33
+ """Process data in batches and execute the given query."""
34
+ for i in range(0, len(data), batch_size):
35
+ batch = data[i:i + batch_size]
36
+ tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})
37
+
38
+ def main():
39
+ # Paths
40
+ xml_file = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml')
41
+ json_file = os.path.join(WORKING_DIR, 'graph_data.json')
42
+
43
+ # Convert XML to JSON
44
+ json_data = convert_xml_to_json(xml_file, json_file)
45
+ if json_data is None:
46
+ return
47
+
48
+ # Load nodes and edges
49
+ nodes = json_data.get('nodes', [])
50
+ edges = json_data.get('edges', [])
51
+
52
+ # Neo4j queries
53
+ create_nodes_query = """
54
+ UNWIND $nodes AS node
55
+ MERGE (e:Entity {id: node.id})
56
+ SET e.entity_type = node.entity_type,
57
+ e.description = node.description,
58
+ e.source_id = node.source_id,
59
+ e.displayName = node.id
60
+ REMOVE e:Entity
61
+ WITH e, node
62
+ CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
63
+ RETURN count(*)
64
+ """
65
+
66
+ create_edges_query = """
67
+ UNWIND $edges AS edge
68
+ MATCH (source {id: edge.source})
69
+ MATCH (target {id: edge.target})
70
+ WITH source, target, edge,
71
+ CASE
72
+ WHEN edge.keywords CONTAINS 'lead' THEN 'lead'
73
+ WHEN edge.keywords CONTAINS 'participate' THEN 'participate'
74
+ WHEN edge.keywords CONTAINS 'uses' THEN 'uses'
75
+ WHEN edge.keywords CONTAINS 'located' THEN 'located'
76
+ WHEN edge.keywords CONTAINS 'occurs' THEN 'occurs'
77
+ ELSE REPLACE(SPLIT(edge.keywords, ',')[0], '\"', '')
78
+ END AS relType
79
+ CALL apoc.create.relationship(source, relType, {
80
+ weight: edge.weight,
81
+ description: edge.description,
82
+ keywords: edge.keywords,
83
+ source_id: edge.source_id
84
+ }, target) YIELD rel
85
+ RETURN count(*)
86
+ """
87
+
88
+ set_displayname_and_labels_query = """
89
+ MATCH (n)
90
+ SET n.displayName = n.id
91
+ WITH n
92
+ CALL apoc.create.setLabels(n, [n.entity_type]) YIELD node
93
+ RETURN count(*)
94
+ """
95
+
96
+ # Create a Neo4j driver
97
+ driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
98
+
99
+ try:
100
+ # Execute queries in batches
101
+ with driver.session() as session:
102
+ # Insert nodes in batches
103
+ session.execute_write(process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES)
104
+
105
+ # Insert edges in batches
106
+ session.execute_write(process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES)
107
+
108
+ # Set displayName and labels
109
+ session.run(set_displayname_and_labels_query)
110
+
111
+ except Exception as e:
112
+ print(f"Error occurred: {e}")
113
+
114
+ finally:
115
+ driver.close()
116
+
117
+ if __name__ == "__main__":
118
+ main()
examples/lightrag_azure_openai_demo.py CHANGED
@@ -122,4 +122,4 @@ print("\nResult (Global):")
  print(rag.query(query_text, param=QueryParam(mode="global")))

  print("\nResult (Hybrid):")
- print(rag.query(query_text, param=QueryParam(mode="hybrid")))

  print(rag.query(query_text, param=QueryParam(mode="global")))

  print("\nResult (Hybrid):")
+ print(rag.query(query_text, param=QueryParam(mode="hybrid")))
examples/lightrag_bedrock_demo.py ADDED
@@ -0,0 +1,36 @@
+ """
+ LightRAG meets Amazon Bedrock ⛰️
+ """
+
+ import os
+ import logging
+
+ from lightrag import LightRAG, QueryParam
+ from lightrag.llm import bedrock_complete, bedrock_embedding
+ from lightrag.utils import EmbeddingFunc
+
+ logging.getLogger("aiobotocore").setLevel(logging.WARNING)
+
+ WORKING_DIR = "./dickens"
+ if not os.path.exists(WORKING_DIR):
+     os.mkdir(WORKING_DIR)
+
+ rag = LightRAG(
+     working_dir=WORKING_DIR,
+     llm_model_func=bedrock_complete,
+     llm_model_name="Anthropic Claude 3 Haiku // Amazon Bedrock",
+     embedding_func=EmbeddingFunc(
+         embedding_dim=1024, max_token_size=8192, func=bedrock_embedding
+     ),
+ )
+
+ with open("./book.txt", "r", encoding="utf-8") as f:
+     rag.insert(f.read())
+
+ for mode in ["naive", "local", "global", "hybrid"]:
+     print("\n+-" + "-" * len(mode) + "-+")
+     print(f"| {mode.capitalize()} |")
+     print("+-" + "-" * len(mode) + "-+\n")
+     print(
+         rag.query("What are the top themes in this story?", param=QueryParam(mode=mode))
+     )
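The Bedrock demo relies on AWS credentials being resolvable by the underlying botocore client (the demo silences `aiobotocore` logging). A minimal sketch of supplying them through the standard AWS environment variables; these names come from AWS, not from this commit, and the values are placeholders:

```python
import os

# Standard AWS credential environment variables; values are placeholders.
os.environ["AWS_ACCESS_KEY_ID"] = "<access-key-id>"
os.environ["AWS_SECRET_ACCESS_KEY"] = "<secret-access-key>"
os.environ["AWS_DEFAULT_REGION"] = "us-east-1"  # illustrative region
```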
examples/lightrag_hf_demo.py CHANGED
@@ -1,10 +1,9 @@
1
  import os
2
- import sys
3
 
4
  from lightrag import LightRAG, QueryParam
5
  from lightrag.llm import hf_model_complete, hf_embedding
6
  from lightrag.utils import EmbeddingFunc
7
- from transformers import AutoModel,AutoTokenizer
8
 
9
  WORKING_DIR = "./dickens"
10
 
@@ -13,16 +12,20 @@ if not os.path.exists(WORKING_DIR):
13
 
14
  rag = LightRAG(
15
  working_dir=WORKING_DIR,
16
- llm_model_func=hf_model_complete,
17
- llm_model_name='meta-llama/Llama-3.1-8B-Instruct',
18
  embedding_func=EmbeddingFunc(
19
  embedding_dim=384,
20
  max_token_size=5000,
21
  func=lambda texts: hf_embedding(
22
- texts,
23
- tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
24
- embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
25
- )
 
 
 
 
26
  ),
27
  )
28
 
@@ -31,13 +34,21 @@ with open("./book.txt") as f:
31
  rag.insert(f.read())
32
 
33
  # Perform naive search
34
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
 
 
35
 
36
  # Perform local search
37
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
 
 
38
 
39
  # Perform global search
40
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
 
 
41
 
42
  # Perform hybrid search
43
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
 
 
 
1
  import os
 
2
 
3
  from lightrag import LightRAG, QueryParam
4
  from lightrag.llm import hf_model_complete, hf_embedding
5
  from lightrag.utils import EmbeddingFunc
6
+ from transformers import AutoModel, AutoTokenizer
7
 
8
  WORKING_DIR = "./dickens"
9
 
 
12
 
13
  rag = LightRAG(
14
  working_dir=WORKING_DIR,
15
+ llm_model_func=hf_model_complete,
16
+ llm_model_name="meta-llama/Llama-3.1-8B-Instruct",
17
  embedding_func=EmbeddingFunc(
18
  embedding_dim=384,
19
  max_token_size=5000,
20
  func=lambda texts: hf_embedding(
21
+ texts,
22
+ tokenizer=AutoTokenizer.from_pretrained(
23
+ "sentence-transformers/all-MiniLM-L6-v2"
24
+ ),
25
+ embed_model=AutoModel.from_pretrained(
26
+ "sentence-transformers/all-MiniLM-L6-v2"
27
+ ),
28
+ ),
29
  ),
30
  )
31
 
 
34
  rag.insert(f.read())
35
 
36
  # Perform naive search
37
+ print(
38
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
39
+ )
40
 
41
  # Perform local search
42
+ print(
43
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
44
+ )
45
 
46
  # Perform global search
47
+ print(
48
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
49
+ )
50
 
51
  # Perform hybrid search
52
+ print(
53
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
54
+ )
examples/lightrag_openai_compatible_demo.py CHANGED
@@ -6,10 +6,11 @@ from lightrag.utils import EmbeddingFunc
6
  import numpy as np
7
 
8
  WORKING_DIR = "./dickens"
9
-
10
  if not os.path.exists(WORKING_DIR):
11
  os.mkdir(WORKING_DIR)
12
 
 
13
  async def llm_model_func(
14
  prompt, system_prompt=None, history_messages=[], **kwargs
15
  ) -> str:
@@ -20,17 +21,19 @@ async def llm_model_func(
20
  history_messages=history_messages,
21
  api_key=os.getenv("UPSTAGE_API_KEY"),
22
  base_url="https://api.upstage.ai/v1/solar",
23
- **kwargs
24
  )
25
 
 
26
  async def embedding_func(texts: list[str]) -> np.ndarray:
27
  return await openai_embedding(
28
  texts,
29
  model="solar-embedding-1-large-query",
30
  api_key=os.getenv("UPSTAGE_API_KEY"),
31
- base_url="https://api.upstage.ai/v1/solar"
32
  )
33
 
 
34
  # function test
35
  async def test_funcs():
36
  result = await llm_model_func("How are you?")
@@ -39,6 +42,7 @@ async def test_funcs():
39
  result = await embedding_func(["How are you?"])
40
  print("embedding_func: ", result)
41
 
 
42
  asyncio.run(test_funcs())
43
 
44
 
@@ -46,10 +50,8 @@ rag = LightRAG(
46
  working_dir=WORKING_DIR,
47
  llm_model_func=llm_model_func,
48
  embedding_func=EmbeddingFunc(
49
- embedding_dim=4096,
50
- max_token_size=8192,
51
- func=embedding_func
52
- )
53
  )
54
 
55
 
@@ -57,13 +59,21 @@ with open("./book.txt") as f:
57
  rag.insert(f.read())
58
 
59
  # Perform naive search
60
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
 
 
61
 
62
  # Perform local search
63
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
 
 
64
 
65
  # Perform global search
66
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
 
 
67
 
68
  # Perform hybrid search
69
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
 
 
 
6
  import numpy as np
7
 
8
  WORKING_DIR = "./dickens"
9
+
10
  if not os.path.exists(WORKING_DIR):
11
  os.mkdir(WORKING_DIR)
12
 
13
+
14
  async def llm_model_func(
15
  prompt, system_prompt=None, history_messages=[], **kwargs
16
  ) -> str:
 
21
  history_messages=history_messages,
22
  api_key=os.getenv("UPSTAGE_API_KEY"),
23
  base_url="https://api.upstage.ai/v1/solar",
24
+ **kwargs,
25
  )
26
 
27
+
28
  async def embedding_func(texts: list[str]) -> np.ndarray:
29
  return await openai_embedding(
30
  texts,
31
  model="solar-embedding-1-large-query",
32
  api_key=os.getenv("UPSTAGE_API_KEY"),
33
+ base_url="https://api.upstage.ai/v1/solar",
34
  )
35
 
36
+
37
  # function test
38
  async def test_funcs():
39
  result = await llm_model_func("How are you?")
 
42
  result = await embedding_func(["How are you?"])
43
  print("embedding_func: ", result)
44
 
45
+
46
  asyncio.run(test_funcs())
47
 
48
 
 
50
  working_dir=WORKING_DIR,
51
  llm_model_func=llm_model_func,
52
  embedding_func=EmbeddingFunc(
53
+ embedding_dim=4096, max_token_size=8192, func=embedding_func
54
+ ),
 
 
55
  )
56
 
57
 
 
59
  rag.insert(f.read())
60
 
61
  # Perform naive search
62
+ print(
63
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
64
+ )
65
 
66
  # Perform local search
67
+ print(
68
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
69
+ )
70
 
71
  # Perform global search
72
+ print(
73
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
74
+ )
75
 
76
  # Perform hybrid search
77
+ print(
78
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
79
+ )
examples/lightrag_openai_demo.py CHANGED
@@ -1,9 +1,7 @@
1
  import os
2
- import sys
3
 
4
  from lightrag import LightRAG, QueryParam
5
- from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
6
- from transformers import AutoModel,AutoTokenizer
7
 
8
  WORKING_DIR = "./dickens"
9
 
@@ -12,7 +10,7 @@ if not os.path.exists(WORKING_DIR):
12
 
13
  rag = LightRAG(
14
  working_dir=WORKING_DIR,
15
- llm_model_func=gpt_4o_mini_complete
16
  # llm_model_func=gpt_4o_complete
17
  )
18
 
@@ -21,13 +19,21 @@ with open("./book.txt") as f:
21
  rag.insert(f.read())
22
 
23
  # Perform naive search
24
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
 
 
25
 
26
  # Perform local search
27
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
 
 
28
 
29
  # Perform global search
30
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
 
 
31
 
32
  # Perform hybrid search
33
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
 
 
 
1
  import os
 
2
 
3
  from lightrag import LightRAG, QueryParam
4
+ from lightrag.llm import gpt_4o_mini_complete
 
5
 
6
  WORKING_DIR = "./dickens"
7
 
 
10
 
11
  rag = LightRAG(
12
  working_dir=WORKING_DIR,
13
+ llm_model_func=gpt_4o_mini_complete,
14
  # llm_model_func=gpt_4o_complete
15
  )
16
 
 
19
  rag.insert(f.read())
20
 
21
  # Perform naive search
22
+ print(
23
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
24
+ )
25
 
26
  # Perform local search
27
+ print(
28
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
29
+ )
30
 
31
  # Perform global search
32
+ print(
33
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
34
+ )
35
 
36
  # Perform hybrid search
37
+ print(
38
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
39
+ )
examples/vram_management_demo.py ADDED
@@ -0,0 +1,82 @@
 
 
1
+ import os
2
+ import time
3
+ from lightrag import LightRAG, QueryParam
4
+ from lightrag.llm import ollama_model_complete, ollama_embedding
5
+ from lightrag.utils import EmbeddingFunc
6
+
7
+ # Working directory and the directory path for text files
8
+ WORKING_DIR = "./dickens"
9
+ TEXT_FILES_DIR = "/llm/mt"
10
+
11
+ # Create the working directory if it doesn't exist
12
+ if not os.path.exists(WORKING_DIR):
13
+ os.mkdir(WORKING_DIR)
14
+
15
+ # Initialize LightRAG
16
+ rag = LightRAG(
17
+ working_dir=WORKING_DIR,
18
+ llm_model_func=ollama_model_complete,
19
+ llm_model_name="qwen2.5:3b-instruct-max-context",
20
+ embedding_func=EmbeddingFunc(
21
+ embedding_dim=768,
22
+ max_token_size=8192,
23
+ func=lambda texts: ollama_embedding(texts, embed_model="nomic-embed-text"),
24
+ ),
25
+ )
26
+
27
+ # Read all .txt files from the TEXT_FILES_DIR directory
28
+ texts = []
29
+ for filename in os.listdir(TEXT_FILES_DIR):
30
+ if filename.endswith('.txt'):
31
+ file_path = os.path.join(TEXT_FILES_DIR, filename)
32
+ with open(file_path, 'r', encoding='utf-8') as file:
33
+ texts.append(file.read())
34
+
35
+ # Batch insert texts into LightRAG with a retry mechanism
36
+ def insert_texts_with_retry(rag, texts, retries=3, delay=5):
37
+ for _ in range(retries):
38
+ try:
39
+ rag.insert(texts)
40
+ return
41
+ except Exception as e:
42
+ print(f"Error occurred during insertion: {e}. Retrying in {delay} seconds...")
43
+ time.sleep(delay)
44
+ raise RuntimeError("Failed to insert texts after multiple retries.")
45
+
46
+ insert_texts_with_retry(rag, texts)
47
+
48
+ # Perform different types of queries and handle potential errors
49
+ try:
50
+ print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
51
+ except Exception as e:
52
+ print(f"Error performing naive search: {e}")
53
+
54
+ try:
55
+ print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
56
+ except Exception as e:
57
+ print(f"Error performing local search: {e}")
58
+
59
+ try:
60
+ print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
61
+ except Exception as e:
62
+ print(f"Error performing global search: {e}")
63
+
64
+ try:
65
+ print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
66
+ except Exception as e:
67
+ print(f"Error performing hybrid search: {e}")
68
+
69
+ # Function to clear VRAM resources
70
+ def clear_vram():
71
+ os.system("sudo nvidia-smi --gpu-reset")
72
+
73
+ # Regularly clear VRAM to prevent overflow
74
+ clear_vram_interval = 3600 # Clear once every hour
75
+ start_time = time.time()
76
+
77
+ while True:
78
+ current_time = time.time()
79
+ if current_time - start_time > clear_vram_interval:
80
+ clear_vram()
81
+ start_time = current_time
82
+ time.sleep(60) # Check the time every minute
lightrag/__init__.py CHANGED
@@ -1,5 +1,5 @@
- from .lightrag import LightRAG, QueryParam

- __version__ = "0.0.6"
  __author__ = "Zirui Guo"
  __url__ = "https://github.com/HKUDS/LightRAG"

+ from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam

+ __version__ = "0.0.7"
  __author__ = "Zirui Guo"
  __url__ = "https://github.com/HKUDS/LightRAG"
lightrag/base.py CHANGED
@@ -12,15 +12,16 @@ TextChunkSchema = TypedDict(

  T = TypeVar("T")

  @dataclass
  class QueryParam:
      mode: Literal["local", "global", "hybrid", "naive"] = "global"
      only_need_context: bool = False
      response_type: str = "Multiple Paragraphs"
      top_k: int = 60
-     max_token_for_text_unit: int = 4000
      max_token_for_global_context: int = 4000
-     max_token_for_local_context: int = 4000


  @dataclass
@@ -36,6 +37,7 @@ class StorageNameSpace:
      """commit the storage operations after querying"""
      pass

  @dataclass
  class BaseVectorStorage(StorageNameSpace):
      embedding_func: EmbeddingFunc
@@ -50,6 +52,7 @@ class BaseVectorStorage(StorageNameSpace):
      """
      raise NotImplementedError

  @dataclass
  class BaseKVStorage(Generic[T], StorageNameSpace):
      async def all_keys(self) -> list[str]:
@@ -72,7 +75,7 @@ class BaseKVStorage(Generic[T], StorageNameSpace):

      async def drop(self):
          raise NotImplementedError
-

  @dataclass
  class BaseGraphStorage(StorageNameSpace):
@@ -113,4 +116,4 @@ class BaseGraphStorage(StorageNameSpace):
      raise NotImplementedError

  async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
-     raise NotImplementedError("Node embedding is not used in lightrag.")


  T = TypeVar("T")

+
  @dataclass
  class QueryParam:
      mode: Literal["local", "global", "hybrid", "naive"] = "global"
      only_need_context: bool = False
      response_type: str = "Multiple Paragraphs"
      top_k: int = 60
+     max_token_for_text_unit: int = 4000
      max_token_for_global_context: int = 4000
+     max_token_for_local_context: int = 4000


  @dataclass

      """commit the storage operations after querying"""
      pass

+
  @dataclass
  class BaseVectorStorage(StorageNameSpace):
      embedding_func: EmbeddingFunc

      """
      raise NotImplementedError

+
  @dataclass
  class BaseKVStorage(Generic[T], StorageNameSpace):
      async def all_keys(self) -> list[str]:

      async def drop(self):
          raise NotImplementedError
+

  @dataclass
  class BaseGraphStorage(StorageNameSpace):

      raise NotImplementedError

  async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
+     raise NotImplementedError("Node embedding is not used in lightrag.")
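The `QueryParam` dataclass above is what callers pass to `rag.query()`. A minimal usage sketch; the values are illustrative, only the field names come from the diff:

```python
from lightrag import LightRAG, QueryParam

rag = LightRAG(working_dir="./dickens")

# Illustrative values: "hybrid" is one of the four modes defined above,
# and top_k overrides the default of 60.
print(
    rag.query(
        "What are the top themes in this story?",
        param=QueryParam(mode="hybrid", top_k=20),
    )
)
```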
lightrag/lightrag.py CHANGED
@@ -3,10 +3,12 @@ import os
3
  from dataclasses import asdict, dataclass, field
4
  from datetime import datetime
5
  from functools import partial
6
- from typing import Type, cast, Any
7
- from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
8
 
9
- from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding, hf_model_complete, hf_embedding
 
 
 
10
  from .operate import (
11
  chunking_by_token_size,
12
  extract_entities,
@@ -37,6 +39,7 @@ from .base import (
37
  QueryParam,
38
  )
39
 
 
40
  def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
41
  try:
42
  loop = asyncio.get_running_loop()
@@ -69,7 +72,6 @@ class LightRAG:
69
  "dimensions": 1536,
70
  "num_walks": 10,
71
  "walk_length": 40,
72
- "num_walks": 10,
73
  "window_size": 2,
74
  "iterations": 3,
75
  "random_seed": 3,
@@ -77,13 +79,13 @@ class LightRAG:
77
  )
78
 
79
  # embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
80
- embedding_func: EmbeddingFunc = field(default_factory=lambda:openai_embedding)
81
  embedding_batch_num: int = 32
82
  embedding_func_max_async: int = 16
83
 
84
  # LLM
85
- llm_model_func: callable = gpt_4o_mini_complete#hf_model_complete#
86
- llm_model_name: str = 'meta-llama/Llama-3.2-1B-Instruct'#'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
87
  llm_model_max_token_size: int = 32768
88
  llm_model_max_async: int = 16
89
  llm_model_kwargs: dict = field(default_factory=dict)
@@ -99,11 +101,11 @@ class LightRAG:
99
  addon_params: dict = field(default_factory=dict)
100
  convert_response_to_json_func: callable = convert_response_to_json
101
 
102
- def __post_init__(self):
103
  log_file = os.path.join(self.working_dir, "lightrag.log")
104
  set_logger(log_file)
105
  logger.info(f"Logger initialized for working directory: {self.working_dir}")
106
-
107
  _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
108
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
109
 
@@ -134,30 +136,24 @@ class LightRAG:
134
  self.embedding_func
135
  )
136
 
137
- self.entities_vdb = (
138
- self.vector_db_storage_cls(
139
- namespace="entities",
140
- global_config=asdict(self),
141
- embedding_func=self.embedding_func,
142
- meta_fields={"entity_name"}
143
- )
144
  )
145
- self.relationships_vdb = (
146
- self.vector_db_storage_cls(
147
- namespace="relationships",
148
- global_config=asdict(self),
149
- embedding_func=self.embedding_func,
150
- meta_fields={"src_id", "tgt_id"}
151
- )
152
  )
153
- self.chunks_vdb = (
154
- self.vector_db_storage_cls(
155
- namespace="chunks",
156
- global_config=asdict(self),
157
- embedding_func=self.embedding_func,
158
- )
159
  )
160
-
161
  self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
162
  partial(self.llm_model_func, hashing_kv=self.llm_response_cache, **self.llm_model_kwargs)
163
  )
@@ -178,7 +174,7 @@ class LightRAG:
178
  _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
179
  new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
180
  if not len(new_docs):
181
- logger.warning(f"All docs are already in the storage")
182
  return
183
  logger.info(f"[New Docs] inserting {len(new_docs)} docs")
184
 
@@ -204,7 +200,7 @@ class LightRAG:
204
  k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
205
  }
206
  if not len(inserting_chunks):
207
- logger.warning(f"All chunks are already in the storage")
208
  return
209
  logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
210
 
@@ -247,7 +243,7 @@ class LightRAG:
247
  def query(self, query: str, param: QueryParam = QueryParam()):
248
  loop = always_get_an_event_loop()
249
  return loop.run_until_complete(self.aquery(query, param))
250
-
251
  async def aquery(self, query: str, param: QueryParam = QueryParam()):
252
  if param.mode == "local":
253
  response = await local_query(
@@ -291,7 +287,6 @@ class LightRAG:
291
  raise ValueError(f"Unknown mode {param.mode}")
292
  await self._query_done()
293
  return response
294
-
295
 
296
  async def _query_done(self):
297
  tasks = []
@@ -300,5 +295,3 @@ class LightRAG:
300
  continue
301
  tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
302
  await asyncio.gather(*tasks)
303
-
304
-
 
3
  from dataclasses import asdict, dataclass, field
4
  from datetime import datetime
5
  from functools import partial
6
+ from typing import Type, cast
 
7
 
8
+ from .llm import (
9
+ gpt_4o_mini_complete,
10
+ openai_embedding,
11
+ )
12
  from .operate import (
13
  chunking_by_token_size,
14
  extract_entities,
 
39
  QueryParam,
40
  )
41
 
42
+
43
  def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
44
  try:
45
  loop = asyncio.get_running_loop()
 
72
  "dimensions": 1536,
73
  "num_walks": 10,
74
  "walk_length": 40,
 
75
  "window_size": 2,
76
  "iterations": 3,
77
  "random_seed": 3,
 
79
  )
80
 
81
  # embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
82
+ embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
83
  embedding_batch_num: int = 32
84
  embedding_func_max_async: int = 16
85
 
86
  # LLM
87
+ llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete#
88
+ llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
89
  llm_model_max_token_size: int = 32768
90
  llm_model_max_async: int = 16
91
  llm_model_kwargs: dict = field(default_factory=dict)
 
101
  addon_params: dict = field(default_factory=dict)
102
  convert_response_to_json_func: callable = convert_response_to_json
103
 
104
+ def __post_init__(self):
105
  log_file = os.path.join(self.working_dir, "lightrag.log")
106
  set_logger(log_file)
107
  logger.info(f"Logger initialized for working directory: {self.working_dir}")
108
+
109
  _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
110
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
111
 
 
136
  self.embedding_func
137
  )
138
 
139
+ self.entities_vdb = self.vector_db_storage_cls(
140
+ namespace="entities",
141
+ global_config=asdict(self),
142
+ embedding_func=self.embedding_func,
143
+ meta_fields={"entity_name"},
 
 
144
  )
145
+ self.relationships_vdb = self.vector_db_storage_cls(
146
+ namespace="relationships",
147
+ global_config=asdict(self),
148
+ embedding_func=self.embedding_func,
149
+ meta_fields={"src_id", "tgt_id"},
 
 
150
  )
151
+ self.chunks_vdb = self.vector_db_storage_cls(
152
+ namespace="chunks",
153
+ global_config=asdict(self),
154
+ embedding_func=self.embedding_func,
 
 
155
  )
156
+
157
  self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
158
  partial(self.llm_model_func, hashing_kv=self.llm_response_cache, **self.llm_model_kwargs)
159
  )
 
174
  _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
175
  new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
176
  if not len(new_docs):
177
+ logger.warning("All docs are already in the storage")
178
  return
179
  logger.info(f"[New Docs] inserting {len(new_docs)} docs")
180
 
 
200
  k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
201
  }
202
  if not len(inserting_chunks):
203
+ logger.warning("All chunks are already in the storage")
204
  return
205
  logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
206
 
 
243
  def query(self, query: str, param: QueryParam = QueryParam()):
244
  loop = always_get_an_event_loop()
245
  return loop.run_until_complete(self.aquery(query, param))
246
+
247
  async def aquery(self, query: str, param: QueryParam = QueryParam()):
248
  if param.mode == "local":
249
  response = await local_query(
 
287
  raise ValueError(f"Unknown mode {param.mode}")
288
  await self._query_done()
289
  return response
 
290
 
291
  async def _query_done(self):
292
  tasks = []
 
295
  continue
296
  tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
297
  await asyncio.gather(*tasks)
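
A minimal driver script helps sanity-check the `lightrag.py` changes above (the `gpt_4o_mini_complete` / `openai_embedding` defaults, the `entities` / `relationships` / `chunks` vector stores, and the `query` wrapper around `aquery`). This is only a sketch, not part of the commit: it assumes `OPENAI_API_KEY` is exported, `lightrag-hku` is installed, and the working directory and input file names are placeholders.

```python
import os

from lightrag import LightRAG, QueryParam

WORKING_DIR = "./dickens"  # placeholder path
os.makedirs(WORKING_DIR, exist_ok=True)

# The dataclass defaults shown above already select gpt_4o_mini_complete for
# completions and openai_embedding for embeddings, so only working_dir is needed.
rag = LightRAG(working_dir=WORKING_DIR)

with open("./book.txt", encoding="utf-8") as f:  # placeholder corpus
    rag.insert(f.read())

# query() drives aquery() through always_get_an_event_loop(); supported modes
# are "naive", "local", "global", and "hybrid".
print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
```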
 
 
lightrag/llm.py CHANGED
@@ -1,4 +1,7 @@
1
  import os
 
 
 
2
  import numpy as np
3
  import ollama
4
  from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
@@ -8,24 +11,34 @@ from tenacity import (
8
  wait_exponential,
9
  retry_if_exception_type,
10
  )
11
- from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
12
  import torch
13
  from .base import BaseKVStorage
14
  from .utils import compute_args_hash, wrap_embedding_func_with_attrs
15
- import copy
16
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 
17
  @retry(
18
  stop=stop_after_attempt(3),
19
  wait=wait_exponential(multiplier=1, min=4, max=10),
20
  retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
21
  )
22
  async def openai_complete_if_cache(
23
- model, prompt, system_prompt=None, history_messages=[], base_url=None, api_key=None, **kwargs
 
 
 
 
 
 
24
  ) -> str:
25
  if api_key:
26
  os.environ["OPENAI_API_KEY"] = api_key
27
 
28
- openai_async_client = AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
 
 
29
  hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
30
  messages = []
31
  if system_prompt:
@@ -48,15 +61,105 @@ async def openai_complete_if_cache(
48
  )
49
  return response.choices[0].message.content

51
  async def hf_model_if_cache(
52
  model, prompt, system_prompt=None, history_messages=[], **kwargs
53
  ) -> str:
54
  model_name = model
55
- hf_tokenizer = AutoTokenizer.from_pretrained(model_name,device_map = 'auto')
56
- if hf_tokenizer.pad_token == None:
57
  # print("use eos token")
58
  hf_tokenizer.pad_token = hf_tokenizer.eos_token
59
- hf_model = AutoModelForCausalLM.from_pretrained(model_name,device_map = 'auto')
60
  hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
61
  messages = []
62
  if system_prompt:
@@ -69,30 +172,51 @@ async def hf_model_if_cache(
69
  if_cache_return = await hashing_kv.get_by_id(args_hash)
70
  if if_cache_return is not None:
71
  return if_cache_return["return"]
72
- input_prompt = ''
73
  try:
74
- input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
75
- except:
 
 
76
  try:
77
  ori_message = copy.deepcopy(messages)
78
- if messages[0]['role'] == "system":
79
- messages[1]['content'] = "<system>" + messages[0]['content'] + "</system>\n" + messages[1]['content']
 
 
 
 
 
80
  messages = messages[1:]
81
- input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
82
- except:
 
 
83
  len_message = len(ori_message)
84
  for msgid in range(len_message):
85
- input_prompt =input_prompt+ '<'+ori_message[msgid]['role']+'>'+ori_message[msgid]['content']+'</'+ori_message[msgid]['role']+'>\n'
86
-
87
- input_ids = hf_tokenizer(input_prompt, return_tensors='pt', padding=True, truncation=True).to("cuda")
88
- output = hf_model.generate(**input_ids, max_new_tokens=200, num_return_sequences=1,early_stopping = True)
89
  response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
90
  if hashing_kv is not None:
91
- await hashing_kv.upsert(
92
- {args_hash: {"return": response_text, "model": model}}
93
- )
94
  return response_text
95
 
 
96
  async def ollama_model_if_cache(
97
  model, prompt, system_prompt=None, history_messages=[], **kwargs
98
  ) -> str:
@@ -124,6 +248,7 @@ async def ollama_model_if_cache(
124
 
125
  return result
126
 
 
127
  async def gpt_4o_complete(
128
  prompt, system_prompt=None, history_messages=[], **kwargs
129
  ) -> str:
@@ -147,10 +272,23 @@ async def gpt_4o_mini_complete(
147
  **kwargs,
148
  )
150
  async def hf_model_complete(
151
  prompt, system_prompt=None, history_messages=[], **kwargs
152
  ) -> str:
153
- model_name = kwargs['hashing_kv'].global_config['llm_model_name']
154
  return await hf_model_if_cache(
155
  model_name,
156
  prompt,
@@ -159,10 +297,11 @@ async def hf_model_complete(
159
  **kwargs,
160
  )
161
 
 
162
  async def ollama_model_complete(
163
  prompt, system_prompt=None, history_messages=[], **kwargs
164
  ) -> str:
165
- model_name = kwargs['hashing_kv'].global_config['llm_model_name']
166
  return await ollama_model_if_cache(
167
  model_name,
168
  prompt,
@@ -171,30 +310,113 @@ async def ollama_model_complete(
171
  **kwargs,
172
  )
173
 
 
174
  @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
175
  @retry(
176
  stop=stop_after_attempt(3),
177
  wait=wait_exponential(multiplier=1, min=4, max=10),
178
  retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
179
  )
180
- async def openai_embedding(texts: list[str], model: str = "text-embedding-3-small", base_url: str = None, api_key: str = None) -> np.ndarray:
 
 
 
 
 
181
  if api_key:
182
  os.environ["OPENAI_API_KEY"] = api_key
183
 
184
- openai_async_client = AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
 
 
185
  response = await openai_async_client.embeddings.create(
186
  model=model, input=texts, encoding_format="float"
187
  )
188
  return np.array([dp.embedding for dp in response.data])
189
 
191
  async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
192
- input_ids = tokenizer(texts, return_tensors='pt', padding=True, truncation=True).input_ids
 
 
193
  with torch.no_grad():
194
  outputs = embed_model(input_ids)
195
  embeddings = outputs.last_hidden_state.mean(dim=1)
196
  return embeddings.detach().numpy()
197
 
 
198
  async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
199
  embed_text = []
200
  ollama_client = ollama.Client(**kwargs)
@@ -204,11 +426,12 @@ async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarra
204
 
205
  return embed_text
206
 
 
207
  if __name__ == "__main__":
208
  import asyncio
209
 
210
  async def main():
211
- result = await gpt_4o_mini_complete('How are you?')
212
  print(result)
213
 
214
  asyncio.run(main())
 
1
  import os
2
+ import copy
3
+ import json
4
+ import aioboto3
5
  import numpy as np
6
  import ollama
7
  from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
 
11
  wait_exponential,
12
  retry_if_exception_type,
13
  )
14
+ from transformers import AutoTokenizer, AutoModelForCausalLM
15
  import torch
16
  from .base import BaseKVStorage
17
  from .utils import compute_args_hash, wrap_embedding_func_with_attrs
18
+
19
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
20
+
21
+
22
  @retry(
23
  stop=stop_after_attempt(3),
24
  wait=wait_exponential(multiplier=1, min=4, max=10),
25
  retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
26
  )
27
  async def openai_complete_if_cache(
28
+ model,
29
+ prompt,
30
+ system_prompt=None,
31
+ history_messages=[],
32
+ base_url=None,
33
+ api_key=None,
34
+ **kwargs,
35
  ) -> str:
36
  if api_key:
37
  os.environ["OPENAI_API_KEY"] = api_key
38
 
39
+ openai_async_client = (
40
+ AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
41
+ )
42
  hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
43
  messages = []
44
  if system_prompt:
 
61
  )
62
  return response.choices[0].message.content
63
 
64
+
65
+ class BedrockError(Exception):
66
+ """Generic error for issues related to Amazon Bedrock"""
67
+
68
+
69
+ @retry(
70
+ stop=stop_after_attempt(5),
71
+ wait=wait_exponential(multiplier=1, max=60),
72
+ retry=retry_if_exception_type((BedrockError)),
73
+ )
74
+ async def bedrock_complete_if_cache(
75
+ model,
76
+ prompt,
77
+ system_prompt=None,
78
+ history_messages=[],
79
+ aws_access_key_id=None,
80
+ aws_secret_access_key=None,
81
+ aws_session_token=None,
82
+ **kwargs,
83
+ ) -> str:
84
+ os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
85
+ "AWS_ACCESS_KEY_ID", aws_access_key_id
86
+ )
87
+ os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
88
+ "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
89
+ )
90
+ os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
91
+ "AWS_SESSION_TOKEN", aws_session_token
92
+ )
93
+
94
+ # Fix message history format
95
+ messages = []
96
+ for history_message in history_messages:
97
+ message = copy.copy(history_message)
98
+ message["content"] = [{"text": message["content"]}]
99
+ messages.append(message)
100
+
101
+ # Add user prompt
102
+ messages.append({"role": "user", "content": [{"text": prompt}]})
103
+
104
+ # Initialize Converse API arguments
105
+ args = {"modelId": model, "messages": messages}
106
+
107
+ # Define system prompt
108
+ if system_prompt:
109
+ args["system"] = [{"text": system_prompt}]
110
+
111
+ # Map and set up inference parameters
112
+ inference_params_map = {
113
+ "max_tokens": "maxTokens",
114
+ "top_p": "topP",
115
+ "stop_sequences": "stopSequences",
116
+ }
117
+ if inference_params := list(
118
+ set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
119
+ ):
120
+ args["inferenceConfig"] = {}
121
+ for param in inference_params:
122
+ args["inferenceConfig"][inference_params_map.get(param, param)] = (
123
+ kwargs.pop(param)
124
+ )
125
+
126
+ hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
127
+ if hashing_kv is not None:
128
+ args_hash = compute_args_hash(model, messages)
129
+ if_cache_return = await hashing_kv.get_by_id(args_hash)
130
+ if if_cache_return is not None:
131
+ return if_cache_return["return"]
132
+
133
+ # Call model via Converse API
134
+ session = aioboto3.Session()
135
+ async with session.client("bedrock-runtime") as bedrock_async_client:
136
+ try:
137
+ response = await bedrock_async_client.converse(**args, **kwargs)
138
+ except Exception as e:
139
+ raise BedrockError(e)
140
+
141
+ if hashing_kv is not None:
142
+ await hashing_kv.upsert(
143
+ {
144
+ args_hash: {
145
+ "return": response["output"]["message"]["content"][0]["text"],
146
+ "model": model,
147
+ }
148
+ }
149
+ )
150
+
151
+ return response["output"]["message"]["content"][0]["text"]
152
+
153
+
154
  async def hf_model_if_cache(
155
  model, prompt, system_prompt=None, history_messages=[], **kwargs
156
  ) -> str:
157
  model_name = model
158
+ hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
159
+ if hf_tokenizer.pad_token is None:
160
  # print("use eos token")
161
  hf_tokenizer.pad_token = hf_tokenizer.eos_token
162
+ hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
163
  hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
164
  messages = []
165
  if system_prompt:
 
172
  if_cache_return = await hashing_kv.get_by_id(args_hash)
173
  if if_cache_return is not None:
174
  return if_cache_return["return"]
175
+ input_prompt = ""
176
  try:
177
+ input_prompt = hf_tokenizer.apply_chat_template(
178
+ messages, tokenize=False, add_generation_prompt=True
179
+ )
180
+ except Exception:
181
  try:
182
  ori_message = copy.deepcopy(messages)
183
+ if messages[0]["role"] == "system":
184
+ messages[1]["content"] = (
185
+ "<system>"
186
+ + messages[0]["content"]
187
+ + "</system>\n"
188
+ + messages[1]["content"]
189
+ )
190
  messages = messages[1:]
191
+ input_prompt = hf_tokenizer.apply_chat_template(
192
+ messages, tokenize=False, add_generation_prompt=True
193
+ )
194
+ except Exception:
195
  len_message = len(ori_message)
196
  for msgid in range(len_message):
197
+ input_prompt = (
198
+ input_prompt
199
+ + "<"
200
+ + ori_message[msgid]["role"]
201
+ + ">"
202
+ + ori_message[msgid]["content"]
203
+ + "</"
204
+ + ori_message[msgid]["role"]
205
+ + ">\n"
206
+ )
207
+
208
+ input_ids = hf_tokenizer(
209
+ input_prompt, return_tensors="pt", padding=True, truncation=True
210
+ ).to("cuda")
211
+ output = hf_model.generate(
212
+ **input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True
213
+ )
214
  response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
215
  if hashing_kv is not None:
216
+ await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
 
 
217
  return response_text
218
 
219
+
220
  async def ollama_model_if_cache(
221
  model, prompt, system_prompt=None, history_messages=[], **kwargs
222
  ) -> str:
 
248
 
249
  return result
250
 
251
+
252
  async def gpt_4o_complete(
253
  prompt, system_prompt=None, history_messages=[], **kwargs
254
  ) -> str:
 
272
  **kwargs,
273
  )
274
 
275
+
276
+ async def bedrock_complete(
277
+ prompt, system_prompt=None, history_messages=[], **kwargs
278
+ ) -> str:
279
+ return await bedrock_complete_if_cache(
280
+ "anthropic.claude-3-haiku-20240307-v1:0",
281
+ prompt,
282
+ system_prompt=system_prompt,
283
+ history_messages=history_messages,
284
+ **kwargs,
285
+ )
286
+
287
+
288
  async def hf_model_complete(
289
  prompt, system_prompt=None, history_messages=[], **kwargs
290
  ) -> str:
291
+ model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
292
  return await hf_model_if_cache(
293
  model_name,
294
  prompt,
 
297
  **kwargs,
298
  )
299
 
300
+
301
  async def ollama_model_complete(
302
  prompt, system_prompt=None, history_messages=[], **kwargs
303
  ) -> str:
304
+ model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
305
  return await ollama_model_if_cache(
306
  model_name,
307
  prompt,
 
310
  **kwargs,
311
  )
312
 
313
+
314
  @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
315
  @retry(
316
  stop=stop_after_attempt(3),
317
  wait=wait_exponential(multiplier=1, min=4, max=10),
318
  retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
319
  )
320
+ async def openai_embedding(
321
+ texts: list[str],
322
+ model: str = "text-embedding-3-small",
323
+ base_url: str = None,
324
+ api_key: str = None,
325
+ ) -> np.ndarray:
326
  if api_key:
327
  os.environ["OPENAI_API_KEY"] = api_key
328
 
329
+ openai_async_client = (
330
+ AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
331
+ )
332
  response = await openai_async_client.embeddings.create(
333
  model=model, input=texts, encoding_format="float"
334
  )
335
  return np.array([dp.embedding for dp in response.data])
336
 
337
 
338
+ # @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
339
+ # @retry(
340
+ # stop=stop_after_attempt(3),
341
+ # wait=wait_exponential(multiplier=1, min=4, max=10),
342
+ # retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
343
+ # )
344
+ async def bedrock_embedding(
345
+ texts: list[str],
346
+ model: str = "amazon.titan-embed-text-v2:0",
347
+ aws_access_key_id=None,
348
+ aws_secret_access_key=None,
349
+ aws_session_token=None,
350
+ ) -> np.ndarray:
351
+ os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
352
+ "AWS_ACCESS_KEY_ID", aws_access_key_id
353
+ )
354
+ os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
355
+ "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
356
+ )
357
+ os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
358
+ "AWS_SESSION_TOKEN", aws_session_token
359
+ )
360
+
361
+ session = aioboto3.Session()
362
+ async with session.client("bedrock-runtime") as bedrock_async_client:
363
+ if (model_provider := model.split(".")[0]) == "amazon":
364
+ embed_texts = []
365
+ for text in texts:
366
+ if "v2" in model:
367
+ body = json.dumps(
368
+ {
369
+ "inputText": text,
370
+ # 'dimensions': embedding_dim,
371
+ "embeddingTypes": ["float"],
372
+ }
373
+ )
374
+ elif "v1" in model:
375
+ body = json.dumps({"inputText": text})
376
+ else:
377
+ raise ValueError(f"Model {model} is not supported!")
378
+
379
+ response = await bedrock_async_client.invoke_model(
380
+ modelId=model,
381
+ body=body,
382
+ accept="application/json",
383
+ contentType="application/json",
384
+ )
385
+
386
+ response_body = await response.get("body").json()
387
+
388
+ embed_texts.append(response_body["embedding"])
389
+ elif model_provider == "cohere":
390
+ body = json.dumps(
391
+ {"texts": texts, "input_type": "search_document", "truncate": "NONE"}
392
+ )
393
+
394
+ response = await bedrock_async_client.invoke_model(
395
+ model=model,
396
+ body=body,
397
+ accept="application/json",
398
+ contentType="application/json",
399
+ )
400
+
401
+ response_body = json.loads(response.get("body").read())
402
+
403
+ embed_texts = response_body["embeddings"]
404
+ else:
405
+ raise ValueError(f"Model provider '{model_provider}' is not supported!")
406
+
407
+ return np.array(embed_texts)
408
+
409
+
410
  async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
411
+ input_ids = tokenizer(
412
+ texts, return_tensors="pt", padding=True, truncation=True
413
+ ).input_ids
414
  with torch.no_grad():
415
  outputs = embed_model(input_ids)
416
  embeddings = outputs.last_hidden_state.mean(dim=1)
417
  return embeddings.detach().numpy()
418
 
419
+
420
  async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
421
  embed_text = []
422
  ollama_client = ollama.Client(**kwargs)
 
426
 
427
  return embed_text
428
 
429
+
430
  if __name__ == "__main__":
431
  import asyncio
432
 
433
  async def main():
434
+ result = await gpt_4o_mini_complete("How are you?")
435
  print(result)
436
 
437
  asyncio.run(main())
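
The new Amazon Bedrock paths added to `lightrag/llm.py` above (`bedrock_complete`, `bedrock_complete_if_cache`, `bedrock_embedding`) can be exercised on their own. The snippet below is a hedged sketch, not part of this commit: it assumes `aioboto3` is installed, that `AWS_ACCESS_KEY_ID` / `AWS_SECRET_ACCESS_KEY` / `AWS_SESSION_TOKEN` are already set in the environment (the helpers copy them into `os.environ`), and that the account can invoke the Claude 3 Haiku and Titan v2 models used as defaults.

```python
import asyncio

from lightrag.llm import bedrock_complete, bedrock_embedding


async def main():
    # Completion via the Converse API wrapper; default model is
    # anthropic.claude-3-haiku-20240307-v1:0, retried on BedrockError.
    answer = await bedrock_complete("Name one novel by Charles Dickens.")
    print(answer)

    # Embeddings via invoke_model; default model is amazon.titan-embed-text-v2:0
    # and the helper returns a numpy array, one row per input text.
    vectors = await bedrock_embedding(["A Tale of Two Cities", "Great Expectations"])
    print(vectors.shape)


if __name__ == "__main__":
    asyncio.run(main())
```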
lightrag/operate.py CHANGED
@@ -25,6 +25,7 @@ from .base import (
25
  )
26
  from .prompt import GRAPH_FIELD_SEP, PROMPTS
27
 
 
28
  def chunking_by_token_size(
29
  content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
30
  ):
@@ -45,6 +46,7 @@ def chunking_by_token_size(
45
  )
46
  return results
47
 
 
48
  async def _handle_entity_relation_summary(
49
  entity_or_relation_name: str,
50
  description: str,
@@ -229,9 +231,10 @@ async def _merge_edges_then_upsert(
229
  description=description,
230
  keywords=keywords,
231
  )
232
-
233
  return edge_data
234
 
 
235
  async def extract_entities(
236
  chunks: dict[str, TextChunkSchema],
237
  knwoledge_graph_inst: BaseGraphStorage,
@@ -352,7 +355,9 @@ async def extract_entities(
352
  logger.warning("Didn't extract any entities, maybe your LLM is not working")
353
  return None
354
  if not len(all_relationships_data):
355
- logger.warning("Didn't extract any relationships, maybe your LLM is not working")
 
 
356
  return None
357
 
358
  if entity_vdb is not None:
@@ -370,7 +375,10 @@ async def extract_entities(
370
  compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
371
  "src_id": dp["src_id"],
372
  "tgt_id": dp["tgt_id"],
373
- "content": dp["keywords"] + dp["src_id"] + dp["tgt_id"] + dp["description"],
 
 
 
374
  }
375
  for dp in all_relationships_data
376
  }
@@ -378,6 +386,7 @@ async def extract_entities(
378
 
379
  return knwoledge_graph_inst
380
 
 
381
  async def local_query(
382
  query,
383
  knowledge_graph_inst: BaseGraphStorage,
@@ -393,19 +402,24 @@ async def local_query(
393
  kw_prompt_temp = PROMPTS["keywords_extraction"]
394
  kw_prompt = kw_prompt_temp.format(query=query)
395
  result = await use_model_func(kw_prompt)
396
-
397
  try:
398
  keywords_data = json.loads(result)
399
  keywords = keywords_data.get("low_level_keywords", [])
400
- keywords = ', '.join(keywords)
401
- except json.JSONDecodeError as e:
402
  try:
403
- result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
404
- result = '{' + result.split('{')[1].split('}')[0] + '}'
 
 
 
 
 
405
 
406
  keywords_data = json.loads(result)
407
  keywords = keywords_data.get("low_level_keywords", [])
408
- keywords = ', '.join(keywords)
409
  # Handle parsing error
410
  except json.JSONDecodeError as e:
411
  print(f"JSON parsing error: {e}")
@@ -430,11 +444,20 @@ async def local_query(
430
  query,
431
  system_prompt=sys_prompt,
432
  )
433
- if len(response)>len(sys_prompt):
434
- response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
435
-
 
 
 
 
 
 
 
 
436
  return response
437
 
 
438
  async def _build_local_query_context(
439
  query,
440
  knowledge_graph_inst: BaseGraphStorage,
@@ -516,6 +539,7 @@ async def _build_local_query_context(
516
  ```
517
  """
518
 
 
519
  async def _find_most_related_text_unit_from_entities(
520
  node_datas: list[dict],
521
  query_param: QueryParam,
@@ -576,6 +600,7 @@ async def _find_most_related_text_unit_from_entities(
576
  all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
577
  return all_text_units
578
 
 
579
  async def _find_most_related_edges_from_entities(
580
  node_datas: list[dict],
581
  query_param: QueryParam,
@@ -609,6 +634,7 @@ async def _find_most_related_edges_from_entities(
609
  )
610
  return all_edges_data
611
 
 
612
  async def global_query(
613
  query,
614
  knowledge_graph_inst: BaseGraphStorage,
@@ -624,20 +650,25 @@ async def global_query(
624
  kw_prompt_temp = PROMPTS["keywords_extraction"]
625
  kw_prompt = kw_prompt_temp.format(query=query)
626
  result = await use_model_func(kw_prompt)
627
-
628
  try:
629
  keywords_data = json.loads(result)
630
  keywords = keywords_data.get("high_level_keywords", [])
631
- keywords = ', '.join(keywords)
632
- except json.JSONDecodeError as e:
633
  try:
634
- result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
635
- result = '{' + result.split('{')[1].split('}')[0] + '}'
 
 
 
 
 
636
 
637
  keywords_data = json.loads(result)
638
  keywords = keywords_data.get("high_level_keywords", [])
639
- keywords = ', '.join(keywords)
640
-
641
  except json.JSONDecodeError as e:
642
  # Handle parsing error
643
  print(f"JSON parsing error: {e}")
@@ -651,12 +682,12 @@ async def global_query(
651
  text_chunks_db,
652
  query_param,
653
  )
654
-
655
  if query_param.only_need_context:
656
  return context
657
  if context is None:
658
  return PROMPTS["fail_response"]
659
-
660
  sys_prompt_temp = PROMPTS["rag_response"]
661
  sys_prompt = sys_prompt_temp.format(
662
  context_data=context, response_type=query_param.response_type
@@ -665,11 +696,20 @@ async def global_query(
665
  query,
666
  system_prompt=sys_prompt,
667
  )
668
- if len(response)>len(sys_prompt):
669
- response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
670
-
 
 
 
 
 
 
 
 
671
  return response
672
 
 
673
  async def _build_global_query_context(
674
  keywords,
675
  knowledge_graph_inst: BaseGraphStorage,
@@ -679,14 +719,14 @@ async def _build_global_query_context(
679
  query_param: QueryParam,
680
  ):
681
  results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
682
-
683
  if not len(results):
684
  return None
685
-
686
  edge_datas = await asyncio.gather(
687
  *[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
688
  )
689
-
690
  if not all([n is not None for n in edge_datas]):
691
  logger.warning("Some edges are missing, maybe the storage is damaged")
692
  edge_degree = await asyncio.gather(
@@ -765,6 +805,7 @@ async def _build_global_query_context(
765
  ```
766
  """
767
 
 
768
  async def _find_most_related_entities_from_relationships(
769
  edge_datas: list[dict],
770
  query_param: QueryParam,
@@ -774,7 +815,7 @@ async def _find_most_related_entities_from_relationships(
774
  for e in edge_datas:
775
  entity_names.add(e["src_id"])
776
  entity_names.add(e["tgt_id"])
777
-
778
  node_datas = await asyncio.gather(
779
  *[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
780
  )
@@ -795,13 +836,13 @@ async def _find_most_related_entities_from_relationships(
795
 
796
  return node_datas
797
 
 
798
  async def _find_related_text_unit_from_relationships(
799
  edge_datas: list[dict],
800
  query_param: QueryParam,
801
  text_chunks_db: BaseKVStorage[TextChunkSchema],
802
  knowledge_graph_inst: BaseGraphStorage,
803
  ):
804
-
805
  text_units = [
806
  split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
807
  for dp in edge_datas
@@ -816,15 +857,13 @@ async def _find_related_text_unit_from_relationships(
816
  "data": await text_chunks_db.get_by_id(c_id),
817
  "order": index,
818
  }
819
-
820
  if any([v is None for v in all_text_units_lookup.values()]):
821
  logger.warning("Text chunks are missing, maybe the storage is damaged")
822
  all_text_units = [
823
  {"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
824
  ]
825
- all_text_units = sorted(
826
- all_text_units, key=lambda x: x["order"]
827
- )
828
  all_text_units = truncate_list_by_token_size(
829
  all_text_units,
830
  key=lambda x: x["data"]["content"],
@@ -834,6 +873,7 @@ async def _find_related_text_unit_from_relationships(
834
 
835
  return all_text_units
836
 
 
837
  async def hybrid_query(
838
  query,
839
  knowledge_graph_inst: BaseGraphStorage,
@@ -849,24 +889,29 @@ async def hybrid_query(
849
 
850
  kw_prompt_temp = PROMPTS["keywords_extraction"]
851
  kw_prompt = kw_prompt_temp.format(query=query)
852
-
853
  result = await use_model_func(kw_prompt)
854
  try:
855
  keywords_data = json.loads(result)
856
  hl_keywords = keywords_data.get("high_level_keywords", [])
857
  ll_keywords = keywords_data.get("low_level_keywords", [])
858
- hl_keywords = ', '.join(hl_keywords)
859
- ll_keywords = ', '.join(ll_keywords)
860
- except json.JSONDecodeError as e:
861
  try:
862
- result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
863
- result = '{' + result.split('{')[1].split('}')[0] + '}'
 
 
 
 
 
864
 
865
  keywords_data = json.loads(result)
866
  hl_keywords = keywords_data.get("high_level_keywords", [])
867
  ll_keywords = keywords_data.get("low_level_keywords", [])
868
- hl_keywords = ', '.join(hl_keywords)
869
- ll_keywords = ', '.join(ll_keywords)
870
  # Handle parsing error
871
  except json.JSONDecodeError as e:
872
  print(f"JSON parsing error: {e}")
@@ -897,7 +942,7 @@ async def hybrid_query(
897
  return context
898
  if context is None:
899
  return PROMPTS["fail_response"]
900
-
901
  sys_prompt_temp = PROMPTS["rag_response"]
902
  sys_prompt = sys_prompt_temp.format(
903
  context_data=context, response_type=query_param.response_type
@@ -906,53 +951,78 @@ async def hybrid_query(
906
  query,
907
  system_prompt=sys_prompt,
908
  )
909
- if len(response)>len(sys_prompt):
910
- response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
 
 
 
 
 
 
 
 
911
  return response
912
 
 
913
  def combine_contexts(high_level_context, low_level_context):
914
  # Function to extract entities, relationships, and sources from context strings
915
 
916
  def extract_sections(context):
917
- entities_match = re.search(r'-----Entities-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
918
- relationships_match = re.search(r'-----Relationships-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
919
- sources_match = re.search(r'-----Sources-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
920
-
921
- entities = entities_match.group(1) if entities_match else ''
922
- relationships = relationships_match.group(1) if relationships_match else ''
923
- sources = sources_match.group(1) if sources_match else ''
924
-
 
 
 
 
 
 
925
  return entities, relationships, sources
926
-
927
  # Extract sections from both contexts
928
 
929
- if high_level_context==None:
930
- warnings.warn("High Level context is None. Return empty High entity/relationship/source")
931
- hl_entities, hl_relationships, hl_sources = '','',''
 
 
932
  else:
933
  hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
934
 
935
-
936
- if low_level_context==None:
937
- warnings.warn("Low Level context is None. Return empty Low entity/relationship/source")
938
- ll_entities, ll_relationships, ll_sources = '','',''
 
939
  else:
940
  ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
941
 
942
-
943
-
944
  # Combine and deduplicate the entities
945
- combined_entities_set = set(filter(None, hl_entities.strip().split('\n') + ll_entities.strip().split('\n')))
946
- combined_entities = '\n'.join(combined_entities_set)
947
-
 
 
948
  # Combine and deduplicate the relationships
949
- combined_relationships_set = set(filter(None, hl_relationships.strip().split('\n') + ll_relationships.strip().split('\n')))
950
- combined_relationships = '\n'.join(combined_relationships_set)
951
-
 
 
 
 
 
952
  # Combine and deduplicate the sources
953
- combined_sources_set = set(filter(None, hl_sources.strip().split('\n') + ll_sources.strip().split('\n')))
954
- combined_sources = '\n'.join(combined_sources_set)
955
-
 
 
956
  # Format the combined context
957
  return f"""
958
  -----Entities-----
@@ -964,6 +1034,7 @@ def combine_contexts(high_level_context, low_level_context):
964
  {combined_sources}
965
  """
966
 
 
967
  async def naive_query(
968
  query,
969
  chunks_vdb: BaseVectorStorage,
@@ -996,8 +1067,16 @@ async def naive_query(
996
  system_prompt=sys_prompt,
997
  )
998
 
999
- if len(response)>len(sys_prompt):
1000
- response = response[len(sys_prompt):].replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
1001
-
1002
- return response
 
 
 
 
 
 
 
1003
 
 
 
25
  )
26
  from .prompt import GRAPH_FIELD_SEP, PROMPTS
27
 
28
+
29
  def chunking_by_token_size(
30
  content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
31
  ):
 
46
  )
47
  return results
48
 
49
+
50
  async def _handle_entity_relation_summary(
51
  entity_or_relation_name: str,
52
  description: str,
 
231
  description=description,
232
  keywords=keywords,
233
  )
234
+
235
  return edge_data
236
 
237
+
238
  async def extract_entities(
239
  chunks: dict[str, TextChunkSchema],
240
  knwoledge_graph_inst: BaseGraphStorage,
 
355
  logger.warning("Didn't extract any entities, maybe your LLM is not working")
356
  return None
357
  if not len(all_relationships_data):
358
+ logger.warning(
359
+ "Didn't extract any relationships, maybe your LLM is not working"
360
+ )
361
  return None
362
 
363
  if entity_vdb is not None:
 
375
  compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
376
  "src_id": dp["src_id"],
377
  "tgt_id": dp["tgt_id"],
378
+ "content": dp["keywords"]
379
+ + dp["src_id"]
380
+ + dp["tgt_id"]
381
+ + dp["description"],
382
  }
383
  for dp in all_relationships_data
384
  }
 
386
 
387
  return knwoledge_graph_inst
388
 
389
+
390
  async def local_query(
391
  query,
392
  knowledge_graph_inst: BaseGraphStorage,
 
402
  kw_prompt_temp = PROMPTS["keywords_extraction"]
403
  kw_prompt = kw_prompt_temp.format(query=query)
404
  result = await use_model_func(kw_prompt)
405
+
406
  try:
407
  keywords_data = json.loads(result)
408
  keywords = keywords_data.get("low_level_keywords", [])
409
+ keywords = ", ".join(keywords)
410
+ except json.JSONDecodeError:
411
  try:
412
+ result = (
413
+ result.replace(kw_prompt[:-1], "")
414
+ .replace("user", "")
415
+ .replace("model", "")
416
+ .strip()
417
+ )
418
+ result = "{" + result.split("{")[1].split("}")[0] + "}"
419
 
420
  keywords_data = json.loads(result)
421
  keywords = keywords_data.get("low_level_keywords", [])
422
+ keywords = ", ".join(keywords)
423
  # Handle parsing error
424
  except json.JSONDecodeError as e:
425
  print(f"JSON parsing error: {e}")
 
444
  query,
445
  system_prompt=sys_prompt,
446
  )
447
+ if len(response) > len(sys_prompt):
448
+ response = (
449
+ response.replace(sys_prompt, "")
450
+ .replace("user", "")
451
+ .replace("model", "")
452
+ .replace(query, "")
453
+ .replace("<system>", "")
454
+ .replace("</system>", "")
455
+ .strip()
456
+ )
457
+
458
  return response
459
 
460
+
461
  async def _build_local_query_context(
462
  query,
463
  knowledge_graph_inst: BaseGraphStorage,
 
539
  ```
540
  """
541
 
542
+
543
  async def _find_most_related_text_unit_from_entities(
544
  node_datas: list[dict],
545
  query_param: QueryParam,
 
600
  all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
601
  return all_text_units
602
 
603
+
604
  async def _find_most_related_edges_from_entities(
605
  node_datas: list[dict],
606
  query_param: QueryParam,
 
634
  )
635
  return all_edges_data
636
 
637
+
638
  async def global_query(
639
  query,
640
  knowledge_graph_inst: BaseGraphStorage,
 
650
  kw_prompt_temp = PROMPTS["keywords_extraction"]
651
  kw_prompt = kw_prompt_temp.format(query=query)
652
  result = await use_model_func(kw_prompt)
653
+
654
  try:
655
  keywords_data = json.loads(result)
656
  keywords = keywords_data.get("high_level_keywords", [])
657
+ keywords = ", ".join(keywords)
658
+ except json.JSONDecodeError:
659
  try:
660
+ result = (
661
+ result.replace(kw_prompt[:-1], "")
662
+ .replace("user", "")
663
+ .replace("model", "")
664
+ .strip()
665
+ )
666
+ result = "{" + result.split("{")[1].split("}")[0] + "}"
667
 
668
  keywords_data = json.loads(result)
669
  keywords = keywords_data.get("high_level_keywords", [])
670
+ keywords = ", ".join(keywords)
671
+
672
  except json.JSONDecodeError as e:
673
  # Handle parsing error
674
  print(f"JSON parsing error: {e}")
 
682
  text_chunks_db,
683
  query_param,
684
  )
685
+
686
  if query_param.only_need_context:
687
  return context
688
  if context is None:
689
  return PROMPTS["fail_response"]
690
+
691
  sys_prompt_temp = PROMPTS["rag_response"]
692
  sys_prompt = sys_prompt_temp.format(
693
  context_data=context, response_type=query_param.response_type
 
696
  query,
697
  system_prompt=sys_prompt,
698
  )
699
+ if len(response) > len(sys_prompt):
700
+ response = (
701
+ response.replace(sys_prompt, "")
702
+ .replace("user", "")
703
+ .replace("model", "")
704
+ .replace(query, "")
705
+ .replace("<system>", "")
706
+ .replace("</system>", "")
707
+ .strip()
708
+ )
709
+
710
  return response
711
 
712
+
713
  async def _build_global_query_context(
714
  keywords,
715
  knowledge_graph_inst: BaseGraphStorage,
 
719
  query_param: QueryParam,
720
  ):
721
  results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
722
+
723
  if not len(results):
724
  return None
725
+
726
  edge_datas = await asyncio.gather(
727
  *[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
728
  )
729
+
730
  if not all([n is not None for n in edge_datas]):
731
  logger.warning("Some edges are missing, maybe the storage is damaged")
732
  edge_degree = await asyncio.gather(
 
805
  ```
806
  """
807
 
808
+
809
  async def _find_most_related_entities_from_relationships(
810
  edge_datas: list[dict],
811
  query_param: QueryParam,
 
815
  for e in edge_datas:
816
  entity_names.add(e["src_id"])
817
  entity_names.add(e["tgt_id"])
818
+
819
  node_datas = await asyncio.gather(
820
  *[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
821
  )
 
836
 
837
  return node_datas
838
 
839
+
840
  async def _find_related_text_unit_from_relationships(
841
  edge_datas: list[dict],
842
  query_param: QueryParam,
843
  text_chunks_db: BaseKVStorage[TextChunkSchema],
844
  knowledge_graph_inst: BaseGraphStorage,
845
  ):
 
846
  text_units = [
847
  split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
848
  for dp in edge_datas
 
857
  "data": await text_chunks_db.get_by_id(c_id),
858
  "order": index,
859
  }
860
+
861
  if any([v is None for v in all_text_units_lookup.values()]):
862
  logger.warning("Text chunks are missing, maybe the storage is damaged")
863
  all_text_units = [
864
  {"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
865
  ]
866
+ all_text_units = sorted(all_text_units, key=lambda x: x["order"])
 
 
867
  all_text_units = truncate_list_by_token_size(
868
  all_text_units,
869
  key=lambda x: x["data"]["content"],
 
873
 
874
  return all_text_units
875
 
876
+
877
  async def hybrid_query(
878
  query,
879
  knowledge_graph_inst: BaseGraphStorage,
 
889
 
890
  kw_prompt_temp = PROMPTS["keywords_extraction"]
891
  kw_prompt = kw_prompt_temp.format(query=query)
892
+
893
  result = await use_model_func(kw_prompt)
894
  try:
895
  keywords_data = json.loads(result)
896
  hl_keywords = keywords_data.get("high_level_keywords", [])
897
  ll_keywords = keywords_data.get("low_level_keywords", [])
898
+ hl_keywords = ", ".join(hl_keywords)
899
+ ll_keywords = ", ".join(ll_keywords)
900
+ except json.JSONDecodeError:
901
  try:
902
+ result = (
903
+ result.replace(kw_prompt[:-1], "")
904
+ .replace("user", "")
905
+ .replace("model", "")
906
+ .strip()
907
+ )
908
+ result = "{" + result.split("{")[1].split("}")[0] + "}"
909
 
910
  keywords_data = json.loads(result)
911
  hl_keywords = keywords_data.get("high_level_keywords", [])
912
  ll_keywords = keywords_data.get("low_level_keywords", [])
913
+ hl_keywords = ", ".join(hl_keywords)
914
+ ll_keywords = ", ".join(ll_keywords)
915
  # Handle parsing error
916
  except json.JSONDecodeError as e:
917
  print(f"JSON parsing error: {e}")
 
942
  return context
943
  if context is None:
944
  return PROMPTS["fail_response"]
945
+
946
  sys_prompt_temp = PROMPTS["rag_response"]
947
  sys_prompt = sys_prompt_temp.format(
948
  context_data=context, response_type=query_param.response_type
 
951
  query,
952
  system_prompt=sys_prompt,
953
  )
954
+ if len(response) > len(sys_prompt):
955
+ response = (
956
+ response.replace(sys_prompt, "")
957
+ .replace("user", "")
958
+ .replace("model", "")
959
+ .replace(query, "")
960
+ .replace("<system>", "")
961
+ .replace("</system>", "")
962
+ .strip()
963
+ )
964
  return response
965
 
966
+
967
  def combine_contexts(high_level_context, low_level_context):
968
  # Function to extract entities, relationships, and sources from context strings
969
 
970
  def extract_sections(context):
971
+ entities_match = re.search(
972
+ r"-----Entities-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
973
+ )
974
+ relationships_match = re.search(
975
+ r"-----Relationships-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
976
+ )
977
+ sources_match = re.search(
978
+ r"-----Sources-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
979
+ )
980
+
981
+ entities = entities_match.group(1) if entities_match else ""
982
+ relationships = relationships_match.group(1) if relationships_match else ""
983
+ sources = sources_match.group(1) if sources_match else ""
984
+
985
  return entities, relationships, sources
986
+
987
  # Extract sections from both contexts
988
 
989
+ if high_level_context is None:
990
+ warnings.warn(
991
+ "High Level context is None. Return empty High entity/relationship/source"
992
+ )
993
+ hl_entities, hl_relationships, hl_sources = "", "", ""
994
  else:
995
  hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
996
 
997
+ if low_level_context is None:
998
+ warnings.warn(
999
+ "Low Level context is None. Return empty Low entity/relationship/source"
1000
+ )
1001
+ ll_entities, ll_relationships, ll_sources = "", "", ""
1002
  else:
1003
  ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
1004
 
 
 
1005
  # Combine and deduplicate the entities
1006
+ combined_entities_set = set(
1007
+ filter(None, hl_entities.strip().split("\n") + ll_entities.strip().split("\n"))
1008
+ )
1009
+ combined_entities = "\n".join(combined_entities_set)
1010
+
1011
  # Combine and deduplicate the relationships
1012
+ combined_relationships_set = set(
1013
+ filter(
1014
+ None,
1015
+ hl_relationships.strip().split("\n") + ll_relationships.strip().split("\n"),
1016
+ )
1017
+ )
1018
+ combined_relationships = "\n".join(combined_relationships_set)
1019
+
1020
  # Combine and deduplicate the sources
1021
+ combined_sources_set = set(
1022
+ filter(None, hl_sources.strip().split("\n") + ll_sources.strip().split("\n"))
1023
+ )
1024
+ combined_sources = "\n".join(combined_sources_set)
1025
+
1026
  # Format the combined context
1027
  return f"""
1028
  -----Entities-----
 
1034
  {combined_sources}
1035
  """
1036
 
1037
+
1038
  async def naive_query(
1039
  query,
1040
  chunks_vdb: BaseVectorStorage,
 
1067
  system_prompt=sys_prompt,
1068
  )
1069
 
1070
+ if len(response) > len(sys_prompt):
1071
+ response = (
1072
+ response[len(sys_prompt) :]
1073
+ .replace(sys_prompt, "")
1074
+ .replace("user", "")
1075
+ .replace("model", "")
1076
+ .replace(query, "")
1077
+ .replace("<system>", "")
1078
+ .replace("</system>", "")
1079
+ .strip()
1080
+ )
1081
 
1082
+ return response
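
The keyword-extraction fallback that gets reformatted above in `local_query`, `global_query`, and `hybrid_query` is easier to follow in isolation. Below is a standalone sketch of the same salvage logic (toy strings only, not code from this commit): if `json.loads` fails because the model echoed the prompt or chat-role markers around the JSON payload, strip them, keep only the outermost `{...}` block, and parse again.

```python
import json


def extract_keywords(result: str, kw_prompt: str, key: str = "low_level_keywords") -> str:
    try:
        keywords = json.loads(result).get(key, [])
    except json.JSONDecodeError:
        # Drop the echoed prompt and bare role markers, then keep the outer {...} block.
        cleaned = (
            result.replace(kw_prompt[:-1], "")
            .replace("user", "")
            .replace("model", "")
            .strip()
        )
        cleaned = "{" + cleaned.split("{")[1].split("}")[0] + "}"
        keywords = json.loads(cleaned).get(key, [])
    return ", ".join(keywords)


print(extract_keywords(
    'model {"low_level_keywords": ["Scrooge", "Christmas Eve"]}',
    "Extract keywords from the query.",
))
# -> Scrooge, Christmas Eve
```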
lightrag/prompt.py CHANGED
@@ -9,9 +9,7 @@ PROMPTS["process_tickers"] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "
9
 
10
  PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"]
11
 
12
- PROMPTS[
13
- "entity_extraction"
14
- ] = """-Goal-
15
  Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
16
 
17
  -Steps-
@@ -32,7 +30,7 @@ Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tupl
32
 
33
  3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
34
  Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_level_keywords>)
35
-
36
  4. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
37
 
38
  5. When finished, output {completion_delimiter}
@@ -146,9 +144,7 @@ PROMPTS[
146
 
147
  PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."
148
 
149
- PROMPTS[
150
- "rag_response"
151
- ] = """---Role---
152
 
153
  You are a helpful assistant responding to questions about data in the tables provided.
154
 
@@ -163,25 +159,10 @@ Do not include information where the supporting evidence for it is not provided.
163
 
164
  {response_type}
165
 
166
-
167
  ---Data tables---
168
 
169
  {context_data}
170
 
171
-
172
- ---Goal---
173
-
174
- Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
175
-
176
- If you don't know the answer, just say so. Do not make anything up.
177
-
178
- Do not include information where the supporting evidence for it is not provided.
179
-
180
-
181
- ---Target response length and format---
182
-
183
- {response_type}
184
-
185
  Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
186
  """
187
 
@@ -241,9 +222,7 @@ Output:
241
 
242
  """
243
 
244
- PROMPTS[
245
- "naive_rag_response"
246
- ] = """You're a helpful assistant
247
  Below are the knowledge you know:
248
  {content_data}
249
  ---
 
9
 
10
  PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"]
11
 
12
+ PROMPTS["entity_extraction"] = """-Goal-
 
 
13
  Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
14
 
15
  -Steps-
 
30
 
31
  3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
32
  Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_level_keywords>)
33
+
34
  4. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
35
 
36
  5. When finished, output {completion_delimiter}
 
144
 
145
  PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."
146
 
147
+ PROMPTS["rag_response"] = """---Role---
 
 
148
 
149
  You are a helpful assistant responding to questions about data in the tables provided.
150
 
 
159
 
160
  {response_type}
161
 
 
162
  ---Data tables---
163
 
164
  {context_data}
165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
167
  """
168
 
 
222
 
223
  """
224
 
225
+ PROMPTS["naive_rag_response"] = """You're a helpful assistant
 
 
226
  Below are the knowledge you know:
227
  {content_data}
228
  ---
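
For reference, the single-assignment `PROMPTS[...]` entries above are plain format strings. The sketch below is illustrative only (not from the commit) and assumes the trimmed `naive_rag_response` template keeps just the `content_data` and `response_type` placeholders; the values passed in are made up.

```python
from lightrag.prompt import PROMPTS

sys_prompt = PROMPTS["naive_rag_response"].format(
    content_data="Chunk 1: ...\nChunk 2: ...",   # retrieved chunk text (placeholder)
    response_type="Multiple Paragraphs",
)
print(sys_prompt)
```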
lightrag/storage.py CHANGED
@@ -1,16 +1,11 @@
  import asyncio
  import html
- import json
  import os
- from collections import defaultdict
- from dataclasses import dataclass, field
  from typing import Any, Union, cast
- import pickle
- import hnswlib
  import networkx as nx
  import numpy as np
  from nano_vectordb import NanoVectorDB
- import xxhash

  from .utils import load_json, logger, write_json
  from .base import (
@@ -19,6 +14,7 @@ from .base import (
19
  BaseVectorStorage,
20
  )
21
 
 
22
  @dataclass
23
  class JsonKVStorage(BaseKVStorage):
24
  def __post_init__(self):
@@ -59,12 +55,12 @@ class JsonKVStorage(BaseKVStorage):
59
  async def drop(self):
60
  self._data = {}
61
 
 
62
  @dataclass
63
  class NanoVectorDBStorage(BaseVectorStorage):
64
  cosine_better_than_threshold: float = 0.2
65
 
66
  def __post_init__(self):
67
-
68
  self._client_file_name = os.path.join(
69
  self.global_config["working_dir"], f"vdb_{self.namespace}.json"
70
  )
@@ -118,6 +114,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
118
  async def index_done_callback(self):
119
  self._client.save()
120
 
 
121
  @dataclass
122
  class NetworkXStorage(BaseGraphStorage):
123
  @staticmethod
@@ -142,7 +139,9 @@ class NetworkXStorage(BaseGraphStorage):
142
 
143
  graph = graph.copy()
144
  graph = cast(nx.Graph, largest_connected_component(graph))
145
- node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()} # type: ignore
 
 
146
  graph = nx.relabel_nodes(graph, node_mapping)
147
  return NetworkXStorage._stabilize_graph(graph)
148
 
 
1
  import asyncio
2
  import html
 
3
  import os
4
+ from dataclasses import dataclass
 
5
  from typing import Any, Union, cast
 
 
6
  import networkx as nx
7
  import numpy as np
8
  from nano_vectordb import NanoVectorDB
 
9
 
10
  from .utils import load_json, logger, write_json
11
  from .base import (
 
14
  BaseVectorStorage,
15
  )
16
 
17
+
18
  @dataclass
19
  class JsonKVStorage(BaseKVStorage):
20
  def __post_init__(self):
 
55
  async def drop(self):
56
  self._data = {}
57
 
58
+
59
  @dataclass
60
  class NanoVectorDBStorage(BaseVectorStorage):
61
  cosine_better_than_threshold: float = 0.2
62
 
63
  def __post_init__(self):
 
64
  self._client_file_name = os.path.join(
65
  self.global_config["working_dir"], f"vdb_{self.namespace}.json"
66
  )
 
114
  async def index_done_callback(self):
115
  self._client.save()
116
 
117
+
118
  @dataclass
119
  class NetworkXStorage(BaseGraphStorage):
120
  @staticmethod
 
139
 
140
  graph = graph.copy()
141
  graph = cast(nx.Graph, largest_connected_component(graph))
142
+ node_mapping = {
143
+ node: html.unescape(node.upper().strip()) for node in graph.nodes()
144
+ } # type: ignore
145
  graph = nx.relabel_nodes(graph, node_mapping)
146
  return NetworkXStorage._stabilize_graph(graph)
147
 
lightrag/utils.py CHANGED
@@ -8,6 +8,7 @@ from dataclasses import dataclass
8
  from functools import wraps
9
  from hashlib import md5
10
  from typing import Any, Union
 
11
 
12
  import numpy as np
13
  import tiktoken
@@ -16,18 +17,22 @@ ENCODER = None
16
 
17
  logger = logging.getLogger("lightrag")
18
 
 
19
  def set_logger(log_file: str):
20
  logger.setLevel(logging.DEBUG)
21
 
22
  file_handler = logging.FileHandler(log_file)
23
  file_handler.setLevel(logging.DEBUG)
24
 
25
- formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 
 
26
  file_handler.setFormatter(formatter)
27
 
28
  if not logger.handlers:
29
  logger.addHandler(file_handler)
30
 
 
31
  @dataclass
32
  class EmbeddingFunc:
33
  embedding_dim: int
@@ -36,7 +41,8 @@ class EmbeddingFunc:
36
 
37
  async def __call__(self, *args, **kwargs) -> np.ndarray:
38
  return await self.func(*args, **kwargs)
39
-
 
40
  def locate_json_string_body_from_string(content: str) -> Union[str, None]:
41
  """Locate the JSON string body from a string"""
42
  maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
@@ -45,6 +51,7 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
45
  else:
46
  return None
47
 
 
48
  def convert_response_to_json(response: str) -> dict:
49
  json_str = locate_json_string_body_from_string(response)
50
  assert json_str is not None, f"Unable to parse JSON from response: {response}"
@@ -55,12 +62,15 @@ def convert_response_to_json(response: str) -> dict:
55
  logger.error(f"Failed to parse JSON: {json_str}")
56
  raise e from None
57
 
 
58
  def compute_args_hash(*args):
59
  return md5(str(args).encode()).hexdigest()
60
 
 
61
  def compute_mdhash_id(content, prefix: str = ""):
62
  return prefix + md5(content.encode()).hexdigest()
63
 
 
64
  def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
65
  """Add restriction of maximum async calling times for a async func"""
66
 
@@ -82,6 +92,7 @@ def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
82
 
83
  return final_decro
84
 
 
85
  def wrap_embedding_func_with_attrs(**kwargs):
86
  """Wrap a function with attributes"""
87
 
@@ -91,16 +102,19 @@ def wrap_embedding_func_with_attrs(**kwargs):
91
 
92
  return final_decro
93
 
 
94
  def load_json(file_name):
95
  if not os.path.exists(file_name):
96
  return None
97
  with open(file_name, encoding="utf-8") as f:
98
  return json.load(f)
99
 
 
100
  def write_json(json_obj, file_name):
101
  with open(file_name, "w", encoding="utf-8") as f:
102
  json.dump(json_obj, f, indent=2, ensure_ascii=False)
103
 
 
104
  def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
105
  global ENCODER
106
  if ENCODER is None:
@@ -116,12 +130,14 @@ def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
116
  content = ENCODER.decode(tokens)
117
  return content
118
 
 
119
  def pack_user_ass_to_openai_messages(*args: str):
120
  roles = ["user", "assistant"]
121
  return [
122
  {"role": roles[i % 2], "content": content} for i, content in enumerate(args)
123
  ]
124
 
 
125
  def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
126
  """Split a string by multiple markers"""
127
  if not markers:
@@ -129,6 +145,7 @@ def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]
129
  results = re.split("|".join(re.escape(marker) for marker in markers), content)
130
  return [r.strip() for r in results if r.strip()]
131
 
 
132
  # Refer the utils functions of the official GraphRAG implementation:
133
  # https://github.com/microsoft/graphrag
134
  def clean_str(input: Any) -> str:
@@ -141,9 +158,11 @@ def clean_str(input: Any) -> str:
141
  # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
142
  return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
143
 
 
144
  def is_float_regex(value):
145
  return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
146
 
 
147
  def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
148
  """Truncate a list of data by token size"""
149
  if max_token_size <= 0:
@@ -155,11 +174,61 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
155
  return list_data[:i]
156
  return list_data
157
 
 
158
  def list_of_list_to_csv(data: list[list]):
159
  return "\n".join(
160
  [",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data]
161
  )
162
 
 
163
  def save_data_to_file(data, file_name):
164
- with open(file_name, 'w', encoding='utf-8') as f:
165
- json.dump(data, f, ensure_ascii=False, indent=4)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  from functools import wraps
9
  from hashlib import md5
10
  from typing import Any, Union
11
+ import xml.etree.ElementTree as ET
12
 
13
  import numpy as np
14
  import tiktoken
 
17
 
18
  logger = logging.getLogger("lightrag")
19
 
20
+
21
  def set_logger(log_file: str):
22
  logger.setLevel(logging.DEBUG)
23
 
24
  file_handler = logging.FileHandler(log_file)
25
  file_handler.setLevel(logging.DEBUG)
26
 
27
+ formatter = logging.Formatter(
28
+ "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
29
+ )
30
  file_handler.setFormatter(formatter)
31
 
32
  if not logger.handlers:
33
  logger.addHandler(file_handler)
34
 
35
+
36
  @dataclass
37
  class EmbeddingFunc:
38
  embedding_dim: int
 
41
 
42
  async def __call__(self, *args, **kwargs) -> np.ndarray:
43
  return await self.func(*args, **kwargs)
44
+
45
+
46
  def locate_json_string_body_from_string(content: str) -> Union[str, None]:
47
  """Locate the JSON string body from a string"""
48
  maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
 
51
  else:
52
  return None
53
 
54
+
55
  def convert_response_to_json(response: str) -> dict:
56
  json_str = locate_json_string_body_from_string(response)
57
  assert json_str is not None, f"Unable to parse JSON from response: {response}"
 
62
  logger.error(f"Failed to parse JSON: {json_str}")
63
  raise e from None
64
 
65
+
66
  def compute_args_hash(*args):
67
  return md5(str(args).encode()).hexdigest()
68
 
69
+
70
  def compute_mdhash_id(content, prefix: str = ""):
71
  return prefix + md5(content.encode()).hexdigest()
72
 
73
+
74
  def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
75
  """Add restriction of maximum async calling times for a async func"""
76
 
 
92
 
93
  return final_decro
94
 
95
+
96
  def wrap_embedding_func_with_attrs(**kwargs):
97
  """Wrap a function with attributes"""
98
 
 
102
 
103
  return final_decro
104
 
105
+
106
  def load_json(file_name):
107
  if not os.path.exists(file_name):
108
  return None
109
  with open(file_name, encoding="utf-8") as f:
110
  return json.load(f)
111
 
112
+
113
  def write_json(json_obj, file_name):
114
  with open(file_name, "w", encoding="utf-8") as f:
115
  json.dump(json_obj, f, indent=2, ensure_ascii=False)
116
 
117
+
118
  def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
119
  global ENCODER
120
  if ENCODER is None:
 
130
  content = ENCODER.decode(tokens)
131
  return content
132
 
133
+
134
  def pack_user_ass_to_openai_messages(*args: str):
135
  roles = ["user", "assistant"]
136
  return [
137
  {"role": roles[i % 2], "content": content} for i, content in enumerate(args)
138
  ]
139
 
140
+
141
  def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
142
  """Split a string by multiple markers"""
143
  if not markers:
 
145
  results = re.split("|".join(re.escape(marker) for marker in markers), content)
146
  return [r.strip() for r in results if r.strip()]
147
 
148
+
149
  # Refer the utils functions of the official GraphRAG implementation:
150
  # https://github.com/microsoft/graphrag
151
  def clean_str(input: Any) -> str:
 
158
  # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
159
  return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
160
 
161
+
162
  def is_float_regex(value):
163
  return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
164
 
165
+
166
  def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
167
  """Truncate a list of data by token size"""
168
  if max_token_size <= 0:
 
174
  return list_data[:i]
175
  return list_data
176
 
177
+
178
  def list_of_list_to_csv(data: list[list]):
179
  return "\n".join(
180
  [",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data]
181
  )
182
 
183
+
184
  def save_data_to_file(data, file_name):
185
+ with open(file_name, "w", encoding="utf-8") as f:
186
+ json.dump(data, f, ensure_ascii=False, indent=4)
187
+
188
+ def xml_to_json(xml_file):
189
+ try:
190
+ tree = ET.parse(xml_file)
191
+ root = tree.getroot()
192
+
193
+ # Print the root element's tag and attributes to confirm the file has been correctly loaded
194
+ print(f"Root element: {root.tag}")
195
+ print(f"Root attributes: {root.attrib}")
196
+
197
+ data = {
198
+ "nodes": [],
199
+ "edges": []
200
+ }
201
+
202
+ # Use namespace
203
+ namespace = {'': 'http://graphml.graphdrawing.org/xmlns'}
204
+
205
+ for node in root.findall('.//node', namespace):
206
+ node_data = {
207
+ "id": node.get('id').strip('"'),
208
+ "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') if node.find("./data[@key='d0']", namespace) is not None else "",
209
+ "description": node.find("./data[@key='d1']", namespace).text if node.find("./data[@key='d1']", namespace) is not None else "",
210
+ "source_id": node.find("./data[@key='d2']", namespace).text if node.find("./data[@key='d2']", namespace) is not None else ""
211
+ }
212
+ data["nodes"].append(node_data)
213
+
214
+ for edge in root.findall('.//edge', namespace):
215
+ edge_data = {
216
+ "source": edge.get('source').strip('"'),
217
+ "target": edge.get('target').strip('"'),
218
+ "weight": float(edge.find("./data[@key='d3']", namespace).text) if edge.find("./data[@key='d3']", namespace) is not None else 0.0,
219
+ "description": edge.find("./data[@key='d4']", namespace).text if edge.find("./data[@key='d4']", namespace) is not None else "",
220
+ "keywords": edge.find("./data[@key='d5']", namespace).text if edge.find("./data[@key='d5']", namespace) is not None else "",
221
+ "source_id": edge.find("./data[@key='d6']", namespace).text if edge.find("./data[@key='d6']", namespace) is not None else ""
222
+ }
223
+ data["edges"].append(edge_data)
224
+
225
+ # Print the number of nodes and edges found
226
+ print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")
227
+
228
+ return data
229
+ except ET.ParseError as e:
230
+ print(f"Error parsing XML file: {e}")
231
+ return None
232
+ except Exception as e:
233
+ print(f"An error occurred: {e}")
234
+ return None
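The new `xml_to_json` helper parses a GraphML file and returns a plain dict with `nodes` and `edges` lists, or `None` if parsing fails. A minimal usage sketch, assuming a GraphML export already exists (the path below is only an example; point it at whatever file your working directory actually contains):

    from lightrag.utils import xml_to_json, write_json

    # Hypothetical GraphML path produced by a previous LightRAG run.
    graph = xml_to_json("../dickens/graph_chunk_entity_relation.graphml")
    if graph is not None:
        print(f"{len(graph['nodes'])} nodes, {len(graph['edges'])} edges")
        write_json(graph, "graph.json")  # reuse the existing write_json helper from the same module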
reproduce/Step_0.py CHANGED
@@ -3,11 +3,11 @@ import json
 import glob
 import argparse
 
-def extract_unique_contexts(input_directory, output_directory):
 
+def extract_unique_contexts(input_directory, output_directory):
     os.makedirs(output_directory, exist_ok=True)
 
-    jsonl_files = glob.glob(os.path.join(input_directory, '*.jsonl'))
+    jsonl_files = glob.glob(os.path.join(input_directory, "*.jsonl"))
     print(f"Found {len(jsonl_files)} JSONL files.")
 
     for file_path in jsonl_files:
@@ -21,18 +21,20 @@ def extract_unique_contexts(input_directory, output_directory):
         print(f"Processing file: {filename}")
 
         try:
-            with open(file_path, 'r', encoding='utf-8') as infile:
+            with open(file_path, "r", encoding="utf-8") as infile:
                 for line_number, line in enumerate(infile, start=1):
                     line = line.strip()
                     if not line:
                         continue
                     try:
                         json_obj = json.loads(line)
-                        context = json_obj.get('context')
+                        context = json_obj.get("context")
                         if context and context not in unique_contexts_dict:
                             unique_contexts_dict[context] = None
                     except json.JSONDecodeError as e:
-                        print(f"JSON decoding error in file {filename} at line {line_number}: {e}")
+                        print(
+                            f"JSON decoding error in file {filename} at line {line_number}: {e}"
+                        )
         except FileNotFoundError:
             print(f"File not found: {filename}")
             continue
@@ -41,10 +43,12 @@ def extract_unique_contexts(input_directory, output_directory):
             continue
 
         unique_contexts_list = list(unique_contexts_dict.keys())
-        print(f"There are {len(unique_contexts_list)} unique `context` entries in the file {filename}.")
+        print(
+            f"There are {len(unique_contexts_list)} unique `context` entries in the file {filename}."
+        )
 
         try:
-            with open(output_path, 'w', encoding='utf-8') as outfile:
+            with open(output_path, "w", encoding="utf-8") as outfile:
                 json.dump(unique_contexts_list, outfile, ensure_ascii=False, indent=4)
             print(f"Unique `context` entries have been saved to: {output_filename}")
         except Exception as e:
@@ -55,8 +59,10 @@ def extract_unique_contexts(input_directory, output_directory):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('-i', '--input_dir', type=str, default='../datasets')
-    parser.add_argument('-o', '--output_dir', type=str, default='../datasets/unique_contexts')
+    parser.add_argument("-i", "--input_dir", type=str, default="../datasets")
+    parser.add_argument(
+        "-o", "--output_dir", type=str, default="../datasets/unique_contexts"
+    )
 
     args = parser.parse_args()
 
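Step_0 keys only on the `context` field of each record, so every line of the `.jsonl` files under `--input_dir` is expected to be a standalone JSON object; a hypothetical input line (any other keys are simply ignored):

    {"context": "Full text of one source document ...", "title": "ignored extra field"}

Duplicate `context` values are collapsed, and the surviving strings are dumped as a JSON array that Step_1 later loads from `../datasets/unique_contexts/{cls}_unique_contexts.json`.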
reproduce/Step_1.py CHANGED
@@ -4,10 +4,11 @@ import time
 
 from lightrag import LightRAG
 
+
 def insert_text(rag, file_path):
-    with open(file_path, mode='r') as f:
+    with open(file_path, mode="r") as f:
         unique_contexts = json.load(f)
-
+
     retries = 0
     max_retries = 3
     while retries < max_retries:
@@ -21,6 +22,7 @@ def insert_text(rag, file_path):
     if retries == max_retries:
         print("Insertion failed after exceeding the maximum number of retries")
 
+
 cls = "agriculture"
 WORKING_DIR = "../{cls}"
 
@@ -29,4 +31,4 @@ if not os.path.exists(WORKING_DIR):
 
 rag = LightRAG(working_dir=WORKING_DIR)
 
-insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")
+insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")
reproduce/Step_1_openai_compatible.py CHANGED
@@ -7,6 +7,7 @@ from lightrag import LightRAG
 from lightrag.utils import EmbeddingFunc
 from lightrag.llm import openai_complete_if_cache, openai_embedding
 
+
 ## For Upstage API
 # please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag direcotry
 async def llm_model_func(
@@ -19,22 +20,26 @@
         history_messages=history_messages,
         api_key=os.getenv("UPSTAGE_API_KEY"),
         base_url="https://api.upstage.ai/v1/solar",
-        **kwargs
+        **kwargs,
     )
 
+
 async def embedding_func(texts: list[str]) -> np.ndarray:
     return await openai_embedding(
         texts,
         model="solar-embedding-1-large-query",
         api_key=os.getenv("UPSTAGE_API_KEY"),
-        base_url="https://api.upstage.ai/v1/solar"
+        base_url="https://api.upstage.ai/v1/solar",
     )
+
+
 ## /For Upstage API
 
+
 def insert_text(rag, file_path):
-    with open(file_path, mode='r') as f:
+    with open(file_path, mode="r") as f:
         unique_contexts = json.load(f)
-
+
     retries = 0
     max_retries = 3
     while retries < max_retries:
@@ -48,19 +53,19 @@ def insert_text(rag, file_path):
     if retries == max_retries:
         print("Insertion failed after exceeding the maximum number of retries")
 
+
 cls = "mix"
 WORKING_DIR = f"../{cls}"
 
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
 
-rag = LightRAG(working_dir=WORKING_DIR,
-               llm_model_func=llm_model_func,
-               embedding_func=EmbeddingFunc(
-                   embedding_dim=4096,
-                   max_token_size=8192,
-                   func=embedding_func
-               )
-               )
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=llm_model_func,
+    embedding_func=EmbeddingFunc(
+        embedding_dim=4096, max_token_size=8192, func=embedding_func
+    ),
+)
 
 insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")
reproduce/Step_2.py CHANGED
@@ -1,8 +1,8 @@
-import os
 import json
 from openai import OpenAI
 from transformers import GPT2Tokenizer
 
+
 def openai_complete_if_cache(
     model="gpt-4o", prompt=None, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
@@ -19,24 +19,26 @@
     )
     return response.choices[0].message.content
 
-tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
+
+tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+
 
 def get_summary(context, tot_tokens=2000):
     tokens = tokenizer.tokenize(context)
     half_tokens = tot_tokens // 2
 
-    start_tokens = tokens[1000:1000 + half_tokens]
-    end_tokens = tokens[-(1000 + half_tokens):1000]
+    start_tokens = tokens[1000 : 1000 + half_tokens]
+    end_tokens = tokens[-(1000 + half_tokens) : 1000]
 
     summary_tokens = start_tokens + end_tokens
     summary = tokenizer.convert_tokens_to_string(summary_tokens)
-
+
     return summary
 
 
-clses = ['agriculture']
+clses = ["agriculture"]
 for cls in clses:
-    with open(f'../datasets/unique_contexts/{cls}_unique_contexts.json', mode='r') as f:
+    with open(f"../datasets/unique_contexts/{cls}_unique_contexts.json", mode="r") as f:
         unique_contexts = json.load(f)
 
     summaries = [get_summary(context) for context in unique_contexts]
@@ -67,10 +69,10 @@ for cls in clses:
     ...
     """
 
-    result = openai_complete_if_cache(model='gpt-4o', prompt=prompt)
+    result = openai_complete_if_cache(model="gpt-4o", prompt=prompt)
 
     file_path = f"../datasets/questions/{cls}_questions.txt"
     with open(file_path, "w") as file:
         file.write(result)
 
-    print(f"{cls}_questions written to {file_path}")
+    print(f"{cls}_questions written to {file_path}")
reproduce/Step_3.py CHANGED
@@ -4,16 +4,18 @@ import asyncio
 from lightrag import LightRAG, QueryParam
 from tqdm import tqdm
 
+
 def extract_queries(file_path):
-    with open(file_path, 'r') as f:
+    with open(file_path, "r") as f:
         data = f.read()
-
-    data = data.replace('**', '')
 
-    queries = re.findall(r'- Question \d+: (.+)', data)
+    data = data.replace("**", "")
+
+    queries = re.findall(r"- Question \d+: (.+)", data)
 
     return queries
 
+
 async def process_query(query_text, rag_instance, query_param):
     try:
         result, context = await rag_instance.aquery(query_text, param=query_param)
@@ -21,6 +23,7 @@ async def process_query(query_text, rag_instance, query_param):
     except Exception as e:
         return None, {"query": query_text, "error": str(e)}
 
+
 def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     try:
         loop = asyncio.get_event_loop()
@@ -29,15 +32,22 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
         asyncio.set_event_loop(loop)
     return loop
 
-def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file, error_file):
+
+def run_queries_and_save_to_json(
+    queries, rag_instance, query_param, output_file, error_file
+):
     loop = always_get_an_event_loop()
 
-    with open(output_file, 'a', encoding='utf-8') as result_file, open(error_file, 'a', encoding='utf-8') as err_file:
+    with open(output_file, "a", encoding="utf-8") as result_file, open(
+        error_file, "a", encoding="utf-8"
+    ) as err_file:
         result_file.write("[\n")
         first_entry = True
 
         for query_text in tqdm(queries, desc="Processing queries", unit="query"):
-            result, error = loop.run_until_complete(process_query(query_text, rag_instance, query_param))
+            result, error = loop.run_until_complete(
+                process_query(query_text, rag_instance, query_param)
+            )
 
             if result:
                 if not first_entry:
@@ -50,6 +60,7 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file
 
         result_file.write("\n]")
 
+
 if __name__ == "__main__":
     cls = "agriculture"
     mode = "hybrid"
@@ -59,4 +70,6 @@ if __name__ == "__main__":
     query_param = QueryParam(mode=mode)
 
     queries = extract_queries(f"../datasets/questions/{cls}_questions.txt")
-    run_queries_and_save_to_json(queries, rag, query_param, f"{cls}_result.json", f"{cls}_errors.json")
+    run_queries_and_save_to_json(
+        queries, rag, query_param, f"{cls}_result.json", f"{cls}_errors.json"
+    )
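`extract_queries` strips `**` markup and keeps only lines matching `- Question \d+: (.+)`, so the questions file generated in Step_2 is expected to contain lines of this shape (the question text here is purely illustrative):

    - Question 1: How do cover crops influence soil nitrogen retention?
    - Question 2: Which irrigation methods reduce water loss in arid regions?

Anything that does not match the pattern is dropped before querying.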
reproduce/Step_3_openai_compatible.py CHANGED
@@ -8,6 +8,7 @@ from lightrag.llm import openai_complete_if_cache, openai_embedding
 from lightrag.utils import EmbeddingFunc
 import numpy as np
 
+
 ## For Upstage API
 # please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag direcotry
 async def llm_model_func(
@@ -20,28 +21,33 @@
         history_messages=history_messages,
         api_key=os.getenv("UPSTAGE_API_KEY"),
         base_url="https://api.upstage.ai/v1/solar",
-        **kwargs
+        **kwargs,
     )
 
+
 async def embedding_func(texts: list[str]) -> np.ndarray:
     return await openai_embedding(
         texts,
         model="solar-embedding-1-large-query",
         api_key=os.getenv("UPSTAGE_API_KEY"),
-        base_url="https://api.upstage.ai/v1/solar"
+        base_url="https://api.upstage.ai/v1/solar",
     )
+
+
 ## /For Upstage API
 
+
 def extract_queries(file_path):
-    with open(file_path, 'r') as f:
+    with open(file_path, "r") as f:
         data = f.read()
-
-    data = data.replace('**', '')
 
-    queries = re.findall(r'- Question \d+: (.+)', data)
+    data = data.replace("**", "")
+
+    queries = re.findall(r"- Question \d+: (.+)", data)
 
     return queries
 
+
 async def process_query(query_text, rag_instance, query_param):
     try:
         result, context = await rag_instance.aquery(query_text, param=query_param)
@@ -49,6 +55,7 @@ async def process_query(query_text, rag_instance, query_param):
     except Exception as e:
         return None, {"query": query_text, "error": str(e)}
 
+
 def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     try:
         loop = asyncio.get_event_loop()
@@ -57,15 +64,22 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
         asyncio.set_event_loop(loop)
     return loop
 
-def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file, error_file):
+
+def run_queries_and_save_to_json(
+    queries, rag_instance, query_param, output_file, error_file
+):
     loop = always_get_an_event_loop()
 
-    with open(output_file, 'a', encoding='utf-8') as result_file, open(error_file, 'a', encoding='utf-8') as err_file:
+    with open(output_file, "a", encoding="utf-8") as result_file, open(
+        error_file, "a", encoding="utf-8"
+    ) as err_file:
         result_file.write("[\n")
         first_entry = True
 
         for query_text in tqdm(queries, desc="Processing queries", unit="query"):
-            result, error = loop.run_until_complete(process_query(query_text, rag_instance, query_param))
+            result, error = loop.run_until_complete(
+                process_query(query_text, rag_instance, query_param)
+            )
 
             if result:
                 if not first_entry:
@@ -78,22 +92,24 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file
 
         result_file.write("\n]")
 
+
 if __name__ == "__main__":
     cls = "mix"
     mode = "hybrid"
     WORKING_DIR = f"../{cls}"
 
     rag = LightRAG(working_dir=WORKING_DIR)
-    rag = LightRAG(working_dir=WORKING_DIR,
-                   llm_model_func=llm_model_func,
-                   embedding_func=EmbeddingFunc(
-                       embedding_dim=4096,
-                       max_token_size=8192,
-                       func=embedding_func
-                   )
-                   )
+    rag = LightRAG(
+        working_dir=WORKING_DIR,
+        llm_model_func=llm_model_func,
+        embedding_func=EmbeddingFunc(
+            embedding_dim=4096, max_token_size=8192, func=embedding_func
+        ),
+    )
     query_param = QueryParam(mode=mode)
 
-    base_dir='../datasets/questions'
+    base_dir = "../datasets/questions"
     queries = extract_queries(f"{base_dir}/{cls}_questions.txt")
-    run_queries_and_save_to_json(queries, rag, query_param, f"{base_dir}/result.json", f"{base_dir}/errors.json")
+    run_queries_and_save_to_json(
+        queries, rag, query_param, f"{base_dir}/result.json", f"{base_dir}/errors.json"
+    )
requirements.txt CHANGED
@@ -1,12 +1,14 @@
-openai
-tiktoken
-networkx
+accelerate
+aioboto3
 graspologic
-nano-vectordb
 hnswlib
-xxhash
+nano-vectordb
+networkx
+ollama
+openai
 tenacity
-transformers
+tiktoken
 torch
-ollama
-accelerate
+transformers
+xxhash
+pyvis