Sanketh Kumar
commited on
Commit
·
df22b26
1
Parent(s):
e600966
chore: added pre-commit-hooks and ruff formatting for commit-hooks
Browse files- .gitignore +2 -1
- .pre-commit-config.yaml +22 -0
- README.md +25 -25
- examples/batch_eval.py +17 -21
- examples/generate_query.py +4 -5
- examples/lightrag_azure_openai_demo.py +1 -1
- examples/lightrag_bedrock_demo.py +4 -9
- examples/lightrag_hf_demo.py +23 -12
- examples/lightrag_ollama_demo.py +15 -10
- examples/lightrag_openai_compatible_demo.py +21 -11
- examples/lightrag_openai_demo.py +14 -8
- lightrag/__init__.py +1 -1
- lightrag/base.py +7 -4
- lightrag/lightrag.py +29 -36
- lightrag/llm.py +144 -79
- lightrag/operate.py +154 -75
- lightrag/prompt.py +4 -10
- lightrag/storage.py +7 -8
- lightrag/utils.py +24 -4
- reproduce/Step_0.py +15 -9
- reproduce/Step_1.py +5 -3
- reproduce/Step_1_openai_compatible.py +17 -12
- reproduce/Step_2.py +11 -9
- reproduce/Step_3.py +21 -8
- reproduce/Step_3_openai_compatible.py +35 -19
- requirements.txt +8 -8
.gitignore
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
__pycache__
|
2 |
*.egg-info
|
3 |
dickens/
|
4 |
-
book.txt
|
|
|
|
1 |
__pycache__
|
2 |
*.egg-info
|
3 |
dickens/
|
4 |
+
book.txt
|
5 |
+
lightrag-dev/
|
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
repos:
|
2 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
3 |
+
rev: v5.0.0
|
4 |
+
hooks:
|
5 |
+
- id: trailing-whitespace
|
6 |
+
- id: end-of-file-fixer
|
7 |
+
- id: requirements-txt-fixer
|
8 |
+
|
9 |
+
|
10 |
+
- repo: https://github.com/astral-sh/ruff-pre-commit
|
11 |
+
rev: v0.6.4
|
12 |
+
hooks:
|
13 |
+
- id: ruff-format
|
14 |
+
- id: ruff
|
15 |
+
args: [--fix]
|
16 |
+
|
17 |
+
|
18 |
+
- repo: https://github.com/mgedmin/check-manifest
|
19 |
+
rev: "0.49"
|
20 |
+
hooks:
|
21 |
+
- id: check-manifest
|
22 |
+
stages: [manual]
|
README.md
CHANGED
@@ -16,16 +16,16 @@
|
|
16 |
<a href="https://pypi.org/project/lightrag-hku/"><img src="https://img.shields.io/pypi/v/lightrag-hku.svg"></a>
|
17 |
<a href="https://pepy.tech/project/lightrag-hku"><img src="https://static.pepy.tech/badge/lightrag-hku/month"></a>
|
18 |
</p>
|
19 |
-
|
20 |
This repository hosts the code of LightRAG. The structure of this code is based on [nano-graphrag](https://github.com/gusye1234/nano-graphrag).
|
21 |

|
22 |
</div>
|
23 |
|
24 |
-
## 🎉 News
|
25 |
- [x] [2024.10.18]🎯🎯📢📢We’ve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
|
26 |
- [x] [2024.10.17]🎯🎯📢📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉
|
27 |
-
- [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
|
28 |
-
- [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
|
29 |
|
30 |
## Install
|
31 |
|
@@ -83,7 +83,7 @@ print(rag.query("What are the top themes in this story?", param=QueryParam(mode=
|
|
83 |
<details>
|
84 |
<summary> Using Open AI-like APIs </summary>
|
85 |
|
86 |
-
LightRAG also
|
87 |
```python
|
88 |
async def llm_model_func(
|
89 |
prompt, system_prompt=None, history_messages=[], **kwargs
|
@@ -120,7 +120,7 @@ rag = LightRAG(
|
|
120 |
|
121 |
<details>
|
122 |
<summary> Using Hugging Face Models </summary>
|
123 |
-
|
124 |
If you want to use Hugging Face models, you only need to set LightRAG as follows:
|
125 |
```python
|
126 |
from lightrag.llm import hf_model_complete, hf_embedding
|
@@ -136,7 +136,7 @@ rag = LightRAG(
|
|
136 |
embedding_dim=384,
|
137 |
max_token_size=5000,
|
138 |
func=lambda texts: hf_embedding(
|
139 |
-
texts,
|
140 |
tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
|
141 |
embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
|
142 |
)
|
@@ -148,7 +148,7 @@ rag = LightRAG(
|
|
148 |
<details>
|
149 |
<summary> Using Ollama Models </summary>
|
150 |
If you want to use Ollama models, you only need to set LightRAG as follows:
|
151 |
-
|
152 |
```python
|
153 |
from lightrag.llm import ollama_model_complete, ollama_embedding
|
154 |
|
@@ -162,7 +162,7 @@ rag = LightRAG(
|
|
162 |
embedding_dim=768,
|
163 |
max_token_size=8192,
|
164 |
func=lambda texts: ollama_embedding(
|
165 |
-
texts,
|
166 |
embed_model="nomic-embed-text"
|
167 |
)
|
168 |
),
|
@@ -187,14 +187,14 @@ with open("./newText.txt") as f:
|
|
187 |
```
|
188 |
## Evaluation
|
189 |
### Dataset
|
190 |
-
The dataset used in LightRAG can be
|
191 |
|
192 |
### Generate Query
|
193 |
-
LightRAG uses the following prompt to generate high-level queries, with the corresponding code
|
194 |
|
195 |
<details>
|
196 |
<summary> Prompt </summary>
|
197 |
-
|
198 |
```python
|
199 |
Given the following description of a dataset:
|
200 |
|
@@ -219,18 +219,18 @@ Output the results in the following structure:
|
|
219 |
...
|
220 |
```
|
221 |
</details>
|
222 |
-
|
223 |
### Batch Eval
|
224 |
To evaluate the performance of two RAG systems on high-level queries, LightRAG uses the following prompt, with the specific code available in `example/batch_eval.py`.
|
225 |
|
226 |
<details>
|
227 |
<summary> Prompt </summary>
|
228 |
-
|
229 |
```python
|
230 |
---Role---
|
231 |
You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
|
232 |
---Goal---
|
233 |
-
You will evaluate two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
|
234 |
|
235 |
- **Comprehensiveness**: How much detail does the answer provide to cover all aspects and details of the question?
|
236 |
- **Diversity**: How varied and rich is the answer in providing different perspectives and insights on the question?
|
@@ -294,7 +294,7 @@ Output your evaluation in the following JSON format:
|
|
294 |
| **Empowerment** | 36.69% | **63.31%** | 45.09% | **54.91%** | 42.81% | **57.19%** | **52.94%** | 47.06% |
|
295 |
| **Overall** | 43.62% | **56.38%** | 45.98% | **54.02%** | 45.70% | **54.30%** | **51.86%** | 48.14% |
|
296 |
|
297 |
-
## Reproduce
|
298 |
All the code can be found in the `./reproduce` directory.
|
299 |
|
300 |
### Step-0 Extract Unique Contexts
|
@@ -302,7 +302,7 @@ First, we need to extract unique contexts in the datasets.
|
|
302 |
|
303 |
<details>
|
304 |
<summary> Code </summary>
|
305 |
-
|
306 |
```python
|
307 |
def extract_unique_contexts(input_directory, output_directory):
|
308 |
|
@@ -361,12 +361,12 @@ For the extracted contexts, we insert them into the LightRAG system.
|
|
361 |
|
362 |
<details>
|
363 |
<summary> Code </summary>
|
364 |
-
|
365 |
```python
|
366 |
def insert_text(rag, file_path):
|
367 |
with open(file_path, mode='r') as f:
|
368 |
unique_contexts = json.load(f)
|
369 |
-
|
370 |
retries = 0
|
371 |
max_retries = 3
|
372 |
while retries < max_retries:
|
@@ -384,11 +384,11 @@ def insert_text(rag, file_path):
|
|
384 |
|
385 |
### Step-2 Generate Queries
|
386 |
|
387 |
-
We extract tokens from
|
388 |
|
389 |
<details>
|
390 |
<summary> Code </summary>
|
391 |
-
|
392 |
```python
|
393 |
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
394 |
|
@@ -401,7 +401,7 @@ def get_summary(context, tot_tokens=2000):
|
|
401 |
|
402 |
summary_tokens = start_tokens + end_tokens
|
403 |
summary = tokenizer.convert_tokens_to_string(summary_tokens)
|
404 |
-
|
405 |
return summary
|
406 |
```
|
407 |
</details>
|
@@ -411,12 +411,12 @@ For the queries generated in Step-2, we will extract them and query LightRAG.
|
|
411 |
|
412 |
<details>
|
413 |
<summary> Code </summary>
|
414 |
-
|
415 |
```python
|
416 |
def extract_queries(file_path):
|
417 |
with open(file_path, 'r') as f:
|
418 |
data = f.read()
|
419 |
-
|
420 |
data = data.replace('**', '')
|
421 |
|
422 |
queries = re.findall(r'- Question \d+: (.+)', data)
|
@@ -470,7 +470,7 @@ def extract_queries(file_path):
|
|
470 |
|
471 |
```python
|
472 |
@article{guo2024lightrag,
|
473 |
-
title={LightRAG: Simple and Fast Retrieval-Augmented Generation},
|
474 |
author={Zirui Guo and Lianghao Xia and Yanhua Yu and Tu Ao and Chao Huang},
|
475 |
year={2024},
|
476 |
eprint={2410.05779},
|
|
|
16 |
<a href="https://pypi.org/project/lightrag-hku/"><img src="https://img.shields.io/pypi/v/lightrag-hku.svg"></a>
|
17 |
<a href="https://pepy.tech/project/lightrag-hku"><img src="https://static.pepy.tech/badge/lightrag-hku/month"></a>
|
18 |
</p>
|
19 |
+
|
20 |
This repository hosts the code of LightRAG. The structure of this code is based on [nano-graphrag](https://github.com/gusye1234/nano-graphrag).
|
21 |

|
22 |
</div>
|
23 |
|
24 |
+
## 🎉 News
|
25 |
- [x] [2024.10.18]🎯🎯📢📢We’ve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
|
26 |
- [x] [2024.10.17]🎯🎯📢📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉
|
27 |
+
- [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
|
28 |
+
- [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
|
29 |
|
30 |
## Install
|
31 |
|
|
|
83 |
<details>
|
84 |
<summary> Using Open AI-like APIs </summary>
|
85 |
|
86 |
+
LightRAG also supports Open AI-like chat/embeddings APIs:
|
87 |
```python
|
88 |
async def llm_model_func(
|
89 |
prompt, system_prompt=None, history_messages=[], **kwargs
|
|
|
120 |
|
121 |
<details>
|
122 |
<summary> Using Hugging Face Models </summary>
|
123 |
+
|
124 |
If you want to use Hugging Face models, you only need to set LightRAG as follows:
|
125 |
```python
|
126 |
from lightrag.llm import hf_model_complete, hf_embedding
|
|
|
136 |
embedding_dim=384,
|
137 |
max_token_size=5000,
|
138 |
func=lambda texts: hf_embedding(
|
139 |
+
texts,
|
140 |
tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
|
141 |
embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
|
142 |
)
|
|
|
148 |
<details>
|
149 |
<summary> Using Ollama Models </summary>
|
150 |
If you want to use Ollama models, you only need to set LightRAG as follows:
|
151 |
+
|
152 |
```python
|
153 |
from lightrag.llm import ollama_model_complete, ollama_embedding
|
154 |
|
|
|
162 |
embedding_dim=768,
|
163 |
max_token_size=8192,
|
164 |
func=lambda texts: ollama_embedding(
|
165 |
+
texts,
|
166 |
embed_model="nomic-embed-text"
|
167 |
)
|
168 |
),
|
|
|
187 |
```
|
188 |
## Evaluation
|
189 |
### Dataset
|
190 |
+
The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
|
191 |
|
192 |
### Generate Query
|
193 |
+
LightRAG uses the following prompt to generate high-level queries, with the corresponding code in `example/generate_query.py`.
|
194 |
|
195 |
<details>
|
196 |
<summary> Prompt </summary>
|
197 |
+
|
198 |
```python
|
199 |
Given the following description of a dataset:
|
200 |
|
|
|
219 |
...
|
220 |
```
|
221 |
</details>
|
222 |
+
|
223 |
### Batch Eval
|
224 |
To evaluate the performance of two RAG systems on high-level queries, LightRAG uses the following prompt, with the specific code available in `example/batch_eval.py`.
|
225 |
|
226 |
<details>
|
227 |
<summary> Prompt </summary>
|
228 |
+
|
229 |
```python
|
230 |
---Role---
|
231 |
You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
|
232 |
---Goal---
|
233 |
+
You will evaluate two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
|
234 |
|
235 |
- **Comprehensiveness**: How much detail does the answer provide to cover all aspects and details of the question?
|
236 |
- **Diversity**: How varied and rich is the answer in providing different perspectives and insights on the question?
|
|
|
294 |
| **Empowerment** | 36.69% | **63.31%** | 45.09% | **54.91%** | 42.81% | **57.19%** | **52.94%** | 47.06% |
|
295 |
| **Overall** | 43.62% | **56.38%** | 45.98% | **54.02%** | 45.70% | **54.30%** | **51.86%** | 48.14% |
|
296 |
|
297 |
+
## Reproduce
|
298 |
All the code can be found in the `./reproduce` directory.
|
299 |
|
300 |
### Step-0 Extract Unique Contexts
|
|
|
302 |
|
303 |
<details>
|
304 |
<summary> Code </summary>
|
305 |
+
|
306 |
```python
|
307 |
def extract_unique_contexts(input_directory, output_directory):
|
308 |
|
|
|
361 |
|
362 |
<details>
|
363 |
<summary> Code </summary>
|
364 |
+
|
365 |
```python
|
366 |
def insert_text(rag, file_path):
|
367 |
with open(file_path, mode='r') as f:
|
368 |
unique_contexts = json.load(f)
|
369 |
+
|
370 |
retries = 0
|
371 |
max_retries = 3
|
372 |
while retries < max_retries:
|
|
|
384 |
|
385 |
### Step-2 Generate Queries
|
386 |
|
387 |
+
We extract tokens from the first and the second half of each context in the dataset, then combine them as dataset descriptions to generate queries.
|
388 |
|
389 |
<details>
|
390 |
<summary> Code </summary>
|
391 |
+
|
392 |
```python
|
393 |
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
394 |
|
|
|
401 |
|
402 |
summary_tokens = start_tokens + end_tokens
|
403 |
summary = tokenizer.convert_tokens_to_string(summary_tokens)
|
404 |
+
|
405 |
return summary
|
406 |
```
|
407 |
</details>
|
|
|
411 |
|
412 |
<details>
|
413 |
<summary> Code </summary>
|
414 |
+
|
415 |
```python
|
416 |
def extract_queries(file_path):
|
417 |
with open(file_path, 'r') as f:
|
418 |
data = f.read()
|
419 |
+
|
420 |
data = data.replace('**', '')
|
421 |
|
422 |
queries = re.findall(r'- Question \d+: (.+)', data)
|
|
|
470 |
|
471 |
```python
|
472 |
@article{guo2024lightrag,
|
473 |
+
title={LightRAG: Simple and Fast Retrieval-Augmented Generation},
|
474 |
author={Zirui Guo and Lianghao Xia and Yanhua Yu and Tu Ao and Chao Huang},
|
475 |
year={2024},
|
476 |
eprint={2410.05779},
|
examples/batch_eval.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1 |
-
import os
|
2 |
import re
|
3 |
import json
|
4 |
import jsonlines
|
@@ -9,28 +8,28 @@ from openai import OpenAI
|
|
9 |
def batch_eval(query_file, result1_file, result2_file, output_file_path):
|
10 |
client = OpenAI()
|
11 |
|
12 |
-
with open(query_file,
|
13 |
data = f.read()
|
14 |
|
15 |
-
queries = re.findall(r
|
16 |
|
17 |
-
with open(result1_file,
|
18 |
answers1 = json.load(f)
|
19 |
-
answers1 = [i[
|
20 |
|
21 |
-
with open(result2_file,
|
22 |
answers2 = json.load(f)
|
23 |
-
answers2 = [i[
|
24 |
|
25 |
requests = []
|
26 |
for i, (query, answer1, answer2) in enumerate(zip(queries, answers1, answers2)):
|
27 |
-
sys_prompt =
|
28 |
---Role---
|
29 |
You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
|
30 |
"""
|
31 |
|
32 |
prompt = f"""
|
33 |
-
You will evaluate two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
|
34 |
|
35 |
- **Comprehensiveness**: How much detail does the answer provide to cover all aspects and details of the question?
|
36 |
- **Diversity**: How varied and rich is the answer in providing different perspectives and insights on the question?
|
@@ -69,7 +68,6 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
|
|
69 |
}}
|
70 |
"""
|
71 |
|
72 |
-
|
73 |
request_data = {
|
74 |
"custom_id": f"request-{i+1}",
|
75 |
"method": "POST",
|
@@ -78,22 +76,21 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
|
|
78 |
"model": "gpt-4o-mini",
|
79 |
"messages": [
|
80 |
{"role": "system", "content": sys_prompt},
|
81 |
-
{"role": "user", "content": prompt}
|
82 |
],
|
83 |
-
}
|
84 |
}
|
85 |
-
|
86 |
requests.append(request_data)
|
87 |
|
88 |
-
with jsonlines.open(output_file_path, mode=
|
89 |
for request in requests:
|
90 |
writer.write(request)
|
91 |
|
92 |
print(f"Batch API requests written to {output_file_path}")
|
93 |
|
94 |
batch_input_file = client.files.create(
|
95 |
-
file=open(output_file_path, "rb"),
|
96 |
-
purpose="batch"
|
97 |
)
|
98 |
batch_input_file_id = batch_input_file.id
|
99 |
|
@@ -101,12 +98,11 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
|
|
101 |
input_file_id=batch_input_file_id,
|
102 |
endpoint="/v1/chat/completions",
|
103 |
completion_window="24h",
|
104 |
-
metadata={
|
105 |
-
"description": "nightly eval job"
|
106 |
-
}
|
107 |
)
|
108 |
|
109 |
-
print(f
|
|
|
110 |
|
111 |
if __name__ == "__main__":
|
112 |
-
batch_eval()
|
|
|
|
|
1 |
import re
|
2 |
import json
|
3 |
import jsonlines
|
|
|
8 |
def batch_eval(query_file, result1_file, result2_file, output_file_path):
|
9 |
client = OpenAI()
|
10 |
|
11 |
+
with open(query_file, "r") as f:
|
12 |
data = f.read()
|
13 |
|
14 |
+
queries = re.findall(r"- Question \d+: (.+)", data)
|
15 |
|
16 |
+
with open(result1_file, "r") as f:
|
17 |
answers1 = json.load(f)
|
18 |
+
answers1 = [i["result"] for i in answers1]
|
19 |
|
20 |
+
with open(result2_file, "r") as f:
|
21 |
answers2 = json.load(f)
|
22 |
+
answers2 = [i["result"] for i in answers2]
|
23 |
|
24 |
requests = []
|
25 |
for i, (query, answer1, answer2) in enumerate(zip(queries, answers1, answers2)):
|
26 |
+
sys_prompt = """
|
27 |
---Role---
|
28 |
You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
|
29 |
"""
|
30 |
|
31 |
prompt = f"""
|
32 |
+
You will evaluate two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
|
33 |
|
34 |
- **Comprehensiveness**: How much detail does the answer provide to cover all aspects and details of the question?
|
35 |
- **Diversity**: How varied and rich is the answer in providing different perspectives and insights on the question?
|
|
|
68 |
}}
|
69 |
"""
|
70 |
|
|
|
71 |
request_data = {
|
72 |
"custom_id": f"request-{i+1}",
|
73 |
"method": "POST",
|
|
|
76 |
"model": "gpt-4o-mini",
|
77 |
"messages": [
|
78 |
{"role": "system", "content": sys_prompt},
|
79 |
+
{"role": "user", "content": prompt},
|
80 |
],
|
81 |
+
},
|
82 |
}
|
83 |
+
|
84 |
requests.append(request_data)
|
85 |
|
86 |
+
with jsonlines.open(output_file_path, mode="w") as writer:
|
87 |
for request in requests:
|
88 |
writer.write(request)
|
89 |
|
90 |
print(f"Batch API requests written to {output_file_path}")
|
91 |
|
92 |
batch_input_file = client.files.create(
|
93 |
+
file=open(output_file_path, "rb"), purpose="batch"
|
|
|
94 |
)
|
95 |
batch_input_file_id = batch_input_file.id
|
96 |
|
|
|
98 |
input_file_id=batch_input_file_id,
|
99 |
endpoint="/v1/chat/completions",
|
100 |
completion_window="24h",
|
101 |
+
metadata={"description": "nightly eval job"},
|
|
|
|
|
102 |
)
|
103 |
|
104 |
+
print(f"Batch {batch.id} has been created.")
|
105 |
+
|
106 |
|
107 |
if __name__ == "__main__":
|
108 |
+
batch_eval()
|
examples/generate_query.py
CHANGED
@@ -1,9 +1,8 @@
|
|
1 |
-
import os
|
2 |
-
|
3 |
from openai import OpenAI
|
4 |
|
5 |
# os.environ["OPENAI_API_KEY"] = ""
|
6 |
|
|
|
7 |
def openai_complete_if_cache(
|
8 |
model="gpt-4o-mini", prompt=None, system_prompt=None, history_messages=[], **kwargs
|
9 |
) -> str:
|
@@ -47,10 +46,10 @@ if __name__ == "__main__":
|
|
47 |
...
|
48 |
"""
|
49 |
|
50 |
-
result = openai_complete_if_cache(model=
|
51 |
|
52 |
-
file_path =
|
53 |
with open(file_path, "w") as file:
|
54 |
file.write(result)
|
55 |
|
56 |
-
print(f"Queries written to {file_path}")
|
|
|
|
|
|
|
1 |
from openai import OpenAI
|
2 |
|
3 |
# os.environ["OPENAI_API_KEY"] = ""
|
4 |
|
5 |
+
|
6 |
def openai_complete_if_cache(
|
7 |
model="gpt-4o-mini", prompt=None, system_prompt=None, history_messages=[], **kwargs
|
8 |
) -> str:
|
|
|
46 |
...
|
47 |
"""
|
48 |
|
49 |
+
result = openai_complete_if_cache(model="gpt-4o-mini", prompt=prompt)
|
50 |
|
51 |
+
file_path = "./queries.txt"
|
52 |
with open(file_path, "w") as file:
|
53 |
file.write(result)
|
54 |
|
55 |
+
print(f"Queries written to {file_path}")
|
examples/lightrag_azure_openai_demo.py
CHANGED
@@ -122,4 +122,4 @@ print("\nResult (Global):")
|
|
122 |
print(rag.query(query_text, param=QueryParam(mode="global")))
|
123 |
|
124 |
print("\nResult (Hybrid):")
|
125 |
-
print(rag.query(query_text, param=QueryParam(mode="hybrid")))
|
|
|
122 |
print(rag.query(query_text, param=QueryParam(mode="global")))
|
123 |
|
124 |
print("\nResult (Hybrid):")
|
125 |
+
print(rag.query(query_text, param=QueryParam(mode="hybrid")))
|
examples/lightrag_bedrock_demo.py
CHANGED
@@ -20,13 +20,11 @@ rag = LightRAG(
|
|
20 |
llm_model_func=bedrock_complete,
|
21 |
llm_model_name="Anthropic Claude 3 Haiku // Amazon Bedrock",
|
22 |
embedding_func=EmbeddingFunc(
|
23 |
-
embedding_dim=1024,
|
24 |
-
|
25 |
-
func=bedrock_embedding
|
26 |
-
)
|
27 |
)
|
28 |
|
29 |
-
with open("./book.txt",
|
30 |
rag.insert(f.read())
|
31 |
|
32 |
for mode in ["naive", "local", "global", "hybrid"]:
|
@@ -34,8 +32,5 @@ for mode in ["naive", "local", "global", "hybrid"]:
|
|
34 |
print(f"| {mode.capitalize()} |")
|
35 |
print("+-" + "-" * len(mode) + "-+\n")
|
36 |
print(
|
37 |
-
rag.query(
|
38 |
-
"What are the top themes in this story?",
|
39 |
-
param=QueryParam(mode=mode)
|
40 |
-
)
|
41 |
)
|
|
|
20 |
llm_model_func=bedrock_complete,
|
21 |
llm_model_name="Anthropic Claude 3 Haiku // Amazon Bedrock",
|
22 |
embedding_func=EmbeddingFunc(
|
23 |
+
embedding_dim=1024, max_token_size=8192, func=bedrock_embedding
|
24 |
+
),
|
|
|
|
|
25 |
)
|
26 |
|
27 |
+
with open("./book.txt", "r", encoding="utf-8") as f:
|
28 |
rag.insert(f.read())
|
29 |
|
30 |
for mode in ["naive", "local", "global", "hybrid"]:
|
|
|
32 |
print(f"| {mode.capitalize()} |")
|
33 |
print("+-" + "-" * len(mode) + "-+\n")
|
34 |
print(
|
35 |
+
rag.query("What are the top themes in this story?", param=QueryParam(mode=mode))
|
|
|
|
|
|
|
36 |
)
|
examples/lightrag_hf_demo.py
CHANGED
@@ -1,10 +1,9 @@
|
|
1 |
import os
|
2 |
-
import sys
|
3 |
|
4 |
from lightrag import LightRAG, QueryParam
|
5 |
from lightrag.llm import hf_model_complete, hf_embedding
|
6 |
from lightrag.utils import EmbeddingFunc
|
7 |
-
from transformers import AutoModel,AutoTokenizer
|
8 |
|
9 |
WORKING_DIR = "./dickens"
|
10 |
|
@@ -13,16 +12,20 @@ if not os.path.exists(WORKING_DIR):
|
|
13 |
|
14 |
rag = LightRAG(
|
15 |
working_dir=WORKING_DIR,
|
16 |
-
llm_model_func=hf_model_complete,
|
17 |
-
llm_model_name=
|
18 |
embedding_func=EmbeddingFunc(
|
19 |
embedding_dim=384,
|
20 |
max_token_size=5000,
|
21 |
func=lambda texts: hf_embedding(
|
22 |
-
texts,
|
23 |
-
tokenizer=AutoTokenizer.from_pretrained(
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
|
26 |
),
|
27 |
)
|
28 |
|
@@ -31,13 +34,21 @@ with open("./book.txt") as f:
|
|
31 |
rag.insert(f.read())
|
32 |
|
33 |
# Perform naive search
|
34 |
-
print(
|
|
|
|
|
35 |
|
36 |
# Perform local search
|
37 |
-
print(
|
|
|
|
|
38 |
|
39 |
# Perform global search
|
40 |
-
print(
|
|
|
|
|
41 |
|
42 |
# Perform hybrid search
|
43 |
-
print(
|
|
|
|
|
|
1 |
import os
|
|
|
2 |
|
3 |
from lightrag import LightRAG, QueryParam
|
4 |
from lightrag.llm import hf_model_complete, hf_embedding
|
5 |
from lightrag.utils import EmbeddingFunc
|
6 |
+
from transformers import AutoModel, AutoTokenizer
|
7 |
|
8 |
WORKING_DIR = "./dickens"
|
9 |
|
|
|
12 |
|
13 |
rag = LightRAG(
|
14 |
working_dir=WORKING_DIR,
|
15 |
+
llm_model_func=hf_model_complete,
|
16 |
+
llm_model_name="meta-llama/Llama-3.1-8B-Instruct",
|
17 |
embedding_func=EmbeddingFunc(
|
18 |
embedding_dim=384,
|
19 |
max_token_size=5000,
|
20 |
func=lambda texts: hf_embedding(
|
21 |
+
texts,
|
22 |
+
tokenizer=AutoTokenizer.from_pretrained(
|
23 |
+
"sentence-transformers/all-MiniLM-L6-v2"
|
24 |
+
),
|
25 |
+
embed_model=AutoModel.from_pretrained(
|
26 |
+
"sentence-transformers/all-MiniLM-L6-v2"
|
27 |
+
),
|
28 |
+
),
|
29 |
),
|
30 |
)
|
31 |
|
|
|
34 |
rag.insert(f.read())
|
35 |
|
36 |
# Perform naive search
|
37 |
+
print(
|
38 |
+
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
|
39 |
+
)
|
40 |
|
41 |
# Perform local search
|
42 |
+
print(
|
43 |
+
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
|
44 |
+
)
|
45 |
|
46 |
# Perform global search
|
47 |
+
print(
|
48 |
+
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
|
49 |
+
)
|
50 |
|
51 |
# Perform hybrid search
|
52 |
+
print(
|
53 |
+
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
|
54 |
+
)
|
examples/lightrag_ollama_demo.py
CHANGED
@@ -11,15 +11,12 @@ if not os.path.exists(WORKING_DIR):
|
|
11 |
|
12 |
rag = LightRAG(
|
13 |
working_dir=WORKING_DIR,
|
14 |
-
llm_model_func=ollama_model_complete,
|
15 |
-
llm_model_name=
|
16 |
embedding_func=EmbeddingFunc(
|
17 |
embedding_dim=768,
|
18 |
max_token_size=8192,
|
19 |
-
func=lambda texts: ollama_embedding(
|
20 |
-
texts,
|
21 |
-
embed_model="nomic-embed-text"
|
22 |
-
)
|
23 |
),
|
24 |
)
|
25 |
|
@@ -28,13 +25,21 @@ with open("./book.txt") as f:
|
|
28 |
rag.insert(f.read())
|
29 |
|
30 |
# Perform naive search
|
31 |
-
print(
|
|
|
|
|
32 |
|
33 |
# Perform local search
|
34 |
-
print(
|
|
|
|
|
35 |
|
36 |
# Perform global search
|
37 |
-
print(
|
|
|
|
|
38 |
|
39 |
# Perform hybrid search
|
40 |
-
print(
|
|
|
|
|
|
11 |
|
12 |
rag = LightRAG(
|
13 |
working_dir=WORKING_DIR,
|
14 |
+
llm_model_func=ollama_model_complete,
|
15 |
+
llm_model_name="your_model_name",
|
16 |
embedding_func=EmbeddingFunc(
|
17 |
embedding_dim=768,
|
18 |
max_token_size=8192,
|
19 |
+
func=lambda texts: ollama_embedding(texts, embed_model="nomic-embed-text"),
|
|
|
|
|
|
|
20 |
),
|
21 |
)
|
22 |
|
|
|
25 |
rag.insert(f.read())
|
26 |
|
27 |
# Perform naive search
|
28 |
+
print(
|
29 |
+
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
|
30 |
+
)
|
31 |
|
32 |
# Perform local search
|
33 |
+
print(
|
34 |
+
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
|
35 |
+
)
|
36 |
|
37 |
# Perform global search
|
38 |
+
print(
|
39 |
+
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
|
40 |
+
)
|
41 |
|
42 |
# Perform hybrid search
|
43 |
+
print(
|
44 |
+
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
|
45 |
+
)
|
examples/lightrag_openai_compatible_demo.py
CHANGED
@@ -6,10 +6,11 @@ from lightrag.utils import EmbeddingFunc
|
|
6 |
import numpy as np
|
7 |
|
8 |
WORKING_DIR = "./dickens"
|
9 |
-
|
10 |
if not os.path.exists(WORKING_DIR):
|
11 |
os.mkdir(WORKING_DIR)
|
12 |
|
|
|
13 |
async def llm_model_func(
|
14 |
prompt, system_prompt=None, history_messages=[], **kwargs
|
15 |
) -> str:
|
@@ -20,17 +21,19 @@ async def llm_model_func(
|
|
20 |
history_messages=history_messages,
|
21 |
api_key=os.getenv("UPSTAGE_API_KEY"),
|
22 |
base_url="https://api.upstage.ai/v1/solar",
|
23 |
-
**kwargs
|
24 |
)
|
25 |
|
|
|
26 |
async def embedding_func(texts: list[str]) -> np.ndarray:
|
27 |
return await openai_embedding(
|
28 |
texts,
|
29 |
model="solar-embedding-1-large-query",
|
30 |
api_key=os.getenv("UPSTAGE_API_KEY"),
|
31 |
-
base_url="https://api.upstage.ai/v1/solar"
|
32 |
)
|
33 |
|
|
|
34 |
# function test
|
35 |
async def test_funcs():
|
36 |
result = await llm_model_func("How are you?")
|
@@ -39,6 +42,7 @@ async def test_funcs():
|
|
39 |
result = await embedding_func(["How are you?"])
|
40 |
print("embedding_func: ", result)
|
41 |
|
|
|
42 |
asyncio.run(test_funcs())
|
43 |
|
44 |
|
@@ -46,10 +50,8 @@ rag = LightRAG(
|
|
46 |
working_dir=WORKING_DIR,
|
47 |
llm_model_func=llm_model_func,
|
48 |
embedding_func=EmbeddingFunc(
|
49 |
-
embedding_dim=4096,
|
50 |
-
|
51 |
-
func=embedding_func
|
52 |
-
)
|
53 |
)
|
54 |
|
55 |
|
@@ -57,13 +59,21 @@ with open("./book.txt") as f:
|
|
57 |
rag.insert(f.read())
|
58 |
|
59 |
# Perform naive search
|
60 |
-
print(
|
|
|
|
|
61 |
|
62 |
# Perform local search
|
63 |
-
print(
|
|
|
|
|
64 |
|
65 |
# Perform global search
|
66 |
-
print(
|
|
|
|
|
67 |
|
68 |
# Perform hybrid search
|
69 |
-
print(
|
|
|
|
|
|
6 |
import numpy as np
|
7 |
|
8 |
WORKING_DIR = "./dickens"
|
9 |
+
|
10 |
if not os.path.exists(WORKING_DIR):
|
11 |
os.mkdir(WORKING_DIR)
|
12 |
|
13 |
+
|
14 |
async def llm_model_func(
|
15 |
prompt, system_prompt=None, history_messages=[], **kwargs
|
16 |
) -> str:
|
|
|
21 |
history_messages=history_messages,
|
22 |
api_key=os.getenv("UPSTAGE_API_KEY"),
|
23 |
base_url="https://api.upstage.ai/v1/solar",
|
24 |
+
**kwargs,
|
25 |
)
|
26 |
|
27 |
+
|
28 |
async def embedding_func(texts: list[str]) -> np.ndarray:
|
29 |
return await openai_embedding(
|
30 |
texts,
|
31 |
model="solar-embedding-1-large-query",
|
32 |
api_key=os.getenv("UPSTAGE_API_KEY"),
|
33 |
+
base_url="https://api.upstage.ai/v1/solar",
|
34 |
)
|
35 |
|
36 |
+
|
37 |
# function test
|
38 |
async def test_funcs():
|
39 |
result = await llm_model_func("How are you?")
|
|
|
42 |
result = await embedding_func(["How are you?"])
|
43 |
print("embedding_func: ", result)
|
44 |
|
45 |
+
|
46 |
asyncio.run(test_funcs())
|
47 |
|
48 |
|
|
|
50 |
working_dir=WORKING_DIR,
|
51 |
llm_model_func=llm_model_func,
|
52 |
embedding_func=EmbeddingFunc(
|
53 |
+
embedding_dim=4096, max_token_size=8192, func=embedding_func
|
54 |
+
),
|
|
|
|
|
55 |
)
|
56 |
|
57 |
|
|
|
59 |
rag.insert(f.read())
|
60 |
|
61 |
# Perform naive search
|
62 |
+
print(
|
63 |
+
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
|
64 |
+
)
|
65 |
|
66 |
# Perform local search
|
67 |
+
print(
|
68 |
+
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
|
69 |
+
)
|
70 |
|
71 |
# Perform global search
|
72 |
+
print(
|
73 |
+
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
|
74 |
+
)
|
75 |
|
76 |
# Perform hybrid search
|
77 |
+
print(
|
78 |
+
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
|
79 |
+
)
|
examples/lightrag_openai_demo.py
CHANGED
@@ -1,9 +1,7 @@
|
|
1 |
import os
|
2 |
-
import sys
|
3 |
|
4 |
from lightrag import LightRAG, QueryParam
|
5 |
-
from lightrag.llm import gpt_4o_mini_complete
|
6 |
-
from transformers import AutoModel,AutoTokenizer
|
7 |
|
8 |
WORKING_DIR = "./dickens"
|
9 |
|
@@ -12,7 +10,7 @@ if not os.path.exists(WORKING_DIR):
|
|
12 |
|
13 |
rag = LightRAG(
|
14 |
working_dir=WORKING_DIR,
|
15 |
-
llm_model_func=gpt_4o_mini_complete
|
16 |
# llm_model_func=gpt_4o_complete
|
17 |
)
|
18 |
|
@@ -21,13 +19,21 @@ with open("./book.txt") as f:
|
|
21 |
rag.insert(f.read())
|
22 |
|
23 |
# Perform naive search
|
24 |
-
print(
|
|
|
|
|
25 |
|
26 |
# Perform local search
|
27 |
-
print(
|
|
|
|
|
28 |
|
29 |
# Perform global search
|
30 |
-
print(
|
|
|
|
|
31 |
|
32 |
# Perform hybrid search
|
33 |
-
print(
|
|
|
|
|
|
1 |
import os
|
|
|
2 |
|
3 |
from lightrag import LightRAG, QueryParam
|
4 |
+
from lightrag.llm import gpt_4o_mini_complete
|
|
|
5 |
|
6 |
WORKING_DIR = "./dickens"
|
7 |
|
|
|
10 |
|
11 |
rag = LightRAG(
|
12 |
working_dir=WORKING_DIR,
|
13 |
+
llm_model_func=gpt_4o_mini_complete,
|
14 |
# llm_model_func=gpt_4o_complete
|
15 |
)
|
16 |
|
|
|
19 |
rag.insert(f.read())
|
20 |
|
21 |
# Perform naive search
|
22 |
+
print(
|
23 |
+
rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
|
24 |
+
)
|
25 |
|
26 |
# Perform local search
|
27 |
+
print(
|
28 |
+
rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
|
29 |
+
)
|
30 |
|
31 |
# Perform global search
|
32 |
+
print(
|
33 |
+
rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
|
34 |
+
)
|
35 |
|
36 |
# Perform hybrid search
|
37 |
+
print(
|
38 |
+
rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
|
39 |
+
)
|
lightrag/__init__.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from .lightrag import LightRAG, QueryParam
|
2 |
|
3 |
__version__ = "0.0.6"
|
4 |
__author__ = "Zirui Guo"
|
|
|
1 |
+
from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
|
2 |
|
3 |
__version__ = "0.0.6"
|
4 |
__author__ = "Zirui Guo"
|
lightrag/base.py
CHANGED
@@ -12,15 +12,16 @@ TextChunkSchema = TypedDict(
|
|
12 |
|
13 |
T = TypeVar("T")
|
14 |
|
|
|
15 |
@dataclass
|
16 |
class QueryParam:
|
17 |
mode: Literal["local", "global", "hybrid", "naive"] = "global"
|
18 |
only_need_context: bool = False
|
19 |
response_type: str = "Multiple Paragraphs"
|
20 |
top_k: int = 60
|
21 |
-
max_token_for_text_unit: int = 4000
|
22 |
max_token_for_global_context: int = 4000
|
23 |
-
max_token_for_local_context: int = 4000
|
24 |
|
25 |
|
26 |
@dataclass
|
@@ -36,6 +37,7 @@ class StorageNameSpace:
|
|
36 |
"""commit the storage operations after querying"""
|
37 |
pass
|
38 |
|
|
|
39 |
@dataclass
|
40 |
class BaseVectorStorage(StorageNameSpace):
|
41 |
embedding_func: EmbeddingFunc
|
@@ -50,6 +52,7 @@ class BaseVectorStorage(StorageNameSpace):
|
|
50 |
"""
|
51 |
raise NotImplementedError
|
52 |
|
|
|
53 |
@dataclass
|
54 |
class BaseKVStorage(Generic[T], StorageNameSpace):
|
55 |
async def all_keys(self) -> list[str]:
|
@@ -72,7 +75,7 @@ class BaseKVStorage(Generic[T], StorageNameSpace):
|
|
72 |
|
73 |
async def drop(self):
|
74 |
raise NotImplementedError
|
75 |
-
|
76 |
|
77 |
@dataclass
|
78 |
class BaseGraphStorage(StorageNameSpace):
|
@@ -113,4 +116,4 @@ class BaseGraphStorage(StorageNameSpace):
|
|
113 |
raise NotImplementedError
|
114 |
|
115 |
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
116 |
-
raise NotImplementedError("Node embedding is not used in lightrag.")
|
|
|
12 |
|
13 |
T = TypeVar("T")
|
14 |
|
15 |
+
|
16 |
@dataclass
|
17 |
class QueryParam:
|
18 |
mode: Literal["local", "global", "hybrid", "naive"] = "global"
|
19 |
only_need_context: bool = False
|
20 |
response_type: str = "Multiple Paragraphs"
|
21 |
top_k: int = 60
|
22 |
+
max_token_for_text_unit: int = 4000
|
23 |
max_token_for_global_context: int = 4000
|
24 |
+
max_token_for_local_context: int = 4000
|
25 |
|
26 |
|
27 |
@dataclass
|
|
|
37 |
"""commit the storage operations after querying"""
|
38 |
pass
|
39 |
|
40 |
+
|
41 |
@dataclass
|
42 |
class BaseVectorStorage(StorageNameSpace):
|
43 |
embedding_func: EmbeddingFunc
|
|
|
52 |
"""
|
53 |
raise NotImplementedError
|
54 |
|
55 |
+
|
56 |
@dataclass
|
57 |
class BaseKVStorage(Generic[T], StorageNameSpace):
|
58 |
async def all_keys(self) -> list[str]:
|
|
|
75 |
|
76 |
async def drop(self):
|
77 |
raise NotImplementedError
|
78 |
+
|
79 |
|
80 |
@dataclass
|
81 |
class BaseGraphStorage(StorageNameSpace):
|
|
|
116 |
raise NotImplementedError
|
117 |
|
118 |
async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
|
119 |
+
raise NotImplementedError("Node embedding is not used in lightrag.")
|
lightrag/lightrag.py
CHANGED
@@ -3,10 +3,12 @@ import os
|
|
3 |
from dataclasses import asdict, dataclass, field
|
4 |
from datetime import datetime
|
5 |
from functools import partial
|
6 |
-
from typing import Type, cast
|
7 |
-
from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
|
8 |
|
9 |
-
from .llm import
|
|
|
|
|
|
|
10 |
from .operate import (
|
11 |
chunking_by_token_size,
|
12 |
extract_entities,
|
@@ -37,6 +39,7 @@ from .base import (
|
|
37 |
QueryParam,
|
38 |
)
|
39 |
|
|
|
40 |
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
41 |
try:
|
42 |
loop = asyncio.get_running_loop()
|
@@ -69,7 +72,6 @@ class LightRAG:
|
|
69 |
"dimensions": 1536,
|
70 |
"num_walks": 10,
|
71 |
"walk_length": 40,
|
72 |
-
"num_walks": 10,
|
73 |
"window_size": 2,
|
74 |
"iterations": 3,
|
75 |
"random_seed": 3,
|
@@ -77,13 +79,13 @@ class LightRAG:
|
|
77 |
)
|
78 |
|
79 |
# embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
|
80 |
-
embedding_func: EmbeddingFunc = field(default_factory=lambda:openai_embedding)
|
81 |
embedding_batch_num: int = 32
|
82 |
embedding_func_max_async: int = 16
|
83 |
|
84 |
# LLM
|
85 |
-
llm_model_func: callable = gpt_4o_mini_complete#hf_model_complete#
|
86 |
-
llm_model_name: str =
|
87 |
llm_model_max_token_size: int = 32768
|
88 |
llm_model_max_async: int = 16
|
89 |
|
@@ -98,11 +100,11 @@ class LightRAG:
|
|
98 |
addon_params: dict = field(default_factory=dict)
|
99 |
convert_response_to_json_func: callable = convert_response_to_json
|
100 |
|
101 |
-
def __post_init__(self):
|
102 |
log_file = os.path.join(self.working_dir, "lightrag.log")
|
103 |
set_logger(log_file)
|
104 |
logger.info(f"Logger initialized for working directory: {self.working_dir}")
|
105 |
-
|
106 |
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
|
107 |
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
108 |
|
@@ -133,30 +135,24 @@ class LightRAG:
|
|
133 |
self.embedding_func
|
134 |
)
|
135 |
|
136 |
-
self.entities_vdb = (
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
meta_fields={"entity_name"}
|
142 |
-
)
|
143 |
)
|
144 |
-
self.relationships_vdb = (
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
meta_fields={"src_id", "tgt_id"}
|
150 |
-
)
|
151 |
)
|
152 |
-
self.chunks_vdb = (
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
embedding_func=self.embedding_func,
|
157 |
-
)
|
158 |
)
|
159 |
-
|
160 |
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
|
161 |
partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
|
162 |
)
|
@@ -177,7 +173,7 @@ class LightRAG:
|
|
177 |
_add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
|
178 |
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
|
179 |
if not len(new_docs):
|
180 |
-
logger.warning(
|
181 |
return
|
182 |
logger.info(f"[New Docs] inserting {len(new_docs)} docs")
|
183 |
|
@@ -203,7 +199,7 @@ class LightRAG:
|
|
203 |
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
|
204 |
}
|
205 |
if not len(inserting_chunks):
|
206 |
-
logger.warning(
|
207 |
return
|
208 |
logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
|
209 |
|
@@ -246,7 +242,7 @@ class LightRAG:
|
|
246 |
def query(self, query: str, param: QueryParam = QueryParam()):
|
247 |
loop = always_get_an_event_loop()
|
248 |
return loop.run_until_complete(self.aquery(query, param))
|
249 |
-
|
250 |
async def aquery(self, query: str, param: QueryParam = QueryParam()):
|
251 |
if param.mode == "local":
|
252 |
response = await local_query(
|
@@ -290,7 +286,6 @@ class LightRAG:
|
|
290 |
raise ValueError(f"Unknown mode {param.mode}")
|
291 |
await self._query_done()
|
292 |
return response
|
293 |
-
|
294 |
|
295 |
async def _query_done(self):
|
296 |
tasks = []
|
@@ -299,5 +294,3 @@ class LightRAG:
|
|
299 |
continue
|
300 |
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
|
301 |
await asyncio.gather(*tasks)
|
302 |
-
|
303 |
-
|
|
|
3 |
from dataclasses import asdict, dataclass, field
|
4 |
from datetime import datetime
|
5 |
from functools import partial
|
6 |
+
from typing import Type, cast
|
|
|
7 |
|
8 |
+
from .llm import (
|
9 |
+
gpt_4o_mini_complete,
|
10 |
+
openai_embedding,
|
11 |
+
)
|
12 |
from .operate import (
|
13 |
chunking_by_token_size,
|
14 |
extract_entities,
|
|
|
39 |
QueryParam,
|
40 |
)
|
41 |
|
42 |
+
|
43 |
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
44 |
try:
|
45 |
loop = asyncio.get_running_loop()
|
|
|
72 |
"dimensions": 1536,
|
73 |
"num_walks": 10,
|
74 |
"walk_length": 40,
|
|
|
75 |
"window_size": 2,
|
76 |
"iterations": 3,
|
77 |
"random_seed": 3,
|
|
|
79 |
)
|
80 |
|
81 |
# embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
|
82 |
+
embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
|
83 |
embedding_batch_num: int = 32
|
84 |
embedding_func_max_async: int = 16
|
85 |
|
86 |
# LLM
|
87 |
+
llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete#
|
88 |
+
llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
|
89 |
llm_model_max_token_size: int = 32768
|
90 |
llm_model_max_async: int = 16
|
91 |
|
|
|
100 |
addon_params: dict = field(default_factory=dict)
|
101 |
convert_response_to_json_func: callable = convert_response_to_json
|
102 |
|
103 |
+
def __post_init__(self):
|
104 |
log_file = os.path.join(self.working_dir, "lightrag.log")
|
105 |
set_logger(log_file)
|
106 |
logger.info(f"Logger initialized for working directory: {self.working_dir}")
|
107 |
+
|
108 |
_print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
|
109 |
logger.debug(f"LightRAG init with param:\n {_print_config}\n")
|
110 |
|
|
|
135 |
self.embedding_func
|
136 |
)
|
137 |
|
138 |
+
self.entities_vdb = self.vector_db_storage_cls(
|
139 |
+
namespace="entities",
|
140 |
+
global_config=asdict(self),
|
141 |
+
embedding_func=self.embedding_func,
|
142 |
+
meta_fields={"entity_name"},
|
|
|
|
|
143 |
)
|
144 |
+
self.relationships_vdb = self.vector_db_storage_cls(
|
145 |
+
namespace="relationships",
|
146 |
+
global_config=asdict(self),
|
147 |
+
embedding_func=self.embedding_func,
|
148 |
+
meta_fields={"src_id", "tgt_id"},
|
|
|
|
|
149 |
)
|
150 |
+
self.chunks_vdb = self.vector_db_storage_cls(
|
151 |
+
namespace="chunks",
|
152 |
+
global_config=asdict(self),
|
153 |
+
embedding_func=self.embedding_func,
|
|
|
|
|
154 |
)
|
155 |
+
|
156 |
self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
|
157 |
partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
|
158 |
)
|
|
|
173 |
_add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
|
174 |
new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
|
175 |
if not len(new_docs):
|
176 |
+
logger.warning("All docs are already in the storage")
|
177 |
return
|
178 |
logger.info(f"[New Docs] inserting {len(new_docs)} docs")
|
179 |
|
|
|
199 |
k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
|
200 |
}
|
201 |
if not len(inserting_chunks):
|
202 |
+
logger.warning("All chunks are already in the storage")
|
203 |
return
|
204 |
logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
|
205 |
|
|
|
242 |
def query(self, query: str, param: QueryParam = QueryParam()):
|
243 |
loop = always_get_an_event_loop()
|
244 |
return loop.run_until_complete(self.aquery(query, param))
|
245 |
+
|
246 |
async def aquery(self, query: str, param: QueryParam = QueryParam()):
|
247 |
if param.mode == "local":
|
248 |
response = await local_query(
|
|
|
286 |
raise ValueError(f"Unknown mode {param.mode}")
|
287 |
await self._query_done()
|
288 |
return response
|
|
|
289 |
|
290 |
async def _query_done(self):
|
291 |
tasks = []
|
|
|
294 |
continue
|
295 |
tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
|
296 |
await asyncio.gather(*tasks)
|
|
|
|
lightrag/llm.py
CHANGED
@@ -1,9 +1,7 @@
|
|
1 |
import os
|
2 |
import copy
|
3 |
import json
|
4 |
-
import botocore
|
5 |
import aioboto3
|
6 |
-
import botocore.errorfactory
|
7 |
import numpy as np
|
8 |
import ollama
|
9 |
from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
|
@@ -13,24 +11,34 @@ from tenacity import (
|
|
13 |
wait_exponential,
|
14 |
retry_if_exception_type,
|
15 |
)
|
16 |
-
from transformers import
|
17 |
import torch
|
18 |
from .base import BaseKVStorage
|
19 |
from .utils import compute_args_hash, wrap_embedding_func_with_attrs
|
20 |
-
|
21 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
|
|
22 |
@retry(
|
23 |
stop=stop_after_attempt(3),
|
24 |
wait=wait_exponential(multiplier=1, min=4, max=10),
|
25 |
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
26 |
)
|
27 |
async def openai_complete_if_cache(
|
28 |
-
model,
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
) -> str:
|
30 |
if api_key:
|
31 |
os.environ["OPENAI_API_KEY"] = api_key
|
32 |
|
33 |
-
openai_async_client =
|
|
|
|
|
34 |
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
35 |
messages = []
|
36 |
if system_prompt:
|
@@ -64,43 +72,56 @@ class BedrockError(Exception):
|
|
64 |
retry=retry_if_exception_type((BedrockError)),
|
65 |
)
|
66 |
async def bedrock_complete_if_cache(
|
67 |
-
model,
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
) -> str:
|
70 |
-
os.environ[
|
71 |
-
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
# Fix message history format
|
75 |
messages = []
|
76 |
for history_message in history_messages:
|
77 |
message = copy.copy(history_message)
|
78 |
-
message[
|
79 |
messages.append(message)
|
80 |
|
81 |
# Add user prompt
|
82 |
-
messages.append({
|
83 |
|
84 |
# Initialize Converse API arguments
|
85 |
-
args = {
|
86 |
-
'modelId': model,
|
87 |
-
'messages': messages
|
88 |
-
}
|
89 |
|
90 |
# Define system prompt
|
91 |
if system_prompt:
|
92 |
-
args[
|
93 |
|
94 |
# Map and set up inference parameters
|
95 |
inference_params_map = {
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
}
|
100 |
-
if
|
101 |
-
|
|
|
|
|
102 |
for param in inference_params:
|
103 |
-
args[
|
|
|
|
|
104 |
|
105 |
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
106 |
if hashing_kv is not None:
|
@@ -112,31 +133,33 @@ async def bedrock_complete_if_cache(
|
|
112 |
# Call model via Converse API
|
113 |
session = aioboto3.Session()
|
114 |
async with session.client("bedrock-runtime") as bedrock_async_client:
|
115 |
-
|
116 |
try:
|
117 |
response = await bedrock_async_client.converse(**args, **kwargs)
|
118 |
except Exception as e:
|
119 |
raise BedrockError(e)
|
120 |
|
121 |
if hashing_kv is not None:
|
122 |
-
await hashing_kv.upsert(
|
123 |
-
|
124 |
-
|
125 |
-
|
|
|
|
|
126 |
}
|
127 |
-
|
|
|
|
|
128 |
|
129 |
-
return response['output']['message']['content'][0]['text']
|
130 |
|
131 |
async def hf_model_if_cache(
|
132 |
model, prompt, system_prompt=None, history_messages=[], **kwargs
|
133 |
) -> str:
|
134 |
model_name = model
|
135 |
-
hf_tokenizer = AutoTokenizer.from_pretrained(model_name,device_map
|
136 |
-
if hf_tokenizer.pad_token
|
137 |
# print("use eos token")
|
138 |
hf_tokenizer.pad_token = hf_tokenizer.eos_token
|
139 |
-
hf_model = AutoModelForCausalLM.from_pretrained(model_name,device_map
|
140 |
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
141 |
messages = []
|
142 |
if system_prompt:
|
@@ -149,30 +172,51 @@ async def hf_model_if_cache(
|
|
149 |
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
150 |
if if_cache_return is not None:
|
151 |
return if_cache_return["return"]
|
152 |
-
input_prompt =
|
153 |
try:
|
154 |
-
input_prompt = hf_tokenizer.apply_chat_template(
|
155 |
-
|
|
|
|
|
156 |
try:
|
157 |
ori_message = copy.deepcopy(messages)
|
158 |
-
if messages[0][
|
159 |
-
messages[1][
|
|
|
|
|
|
|
|
|
|
|
160 |
messages = messages[1:]
|
161 |
-
input_prompt = hf_tokenizer.apply_chat_template(
|
162 |
-
|
|
|
|
|
163 |
len_message = len(ori_message)
|
164 |
for msgid in range(len_message):
|
165 |
-
input_prompt =
|
166 |
-
|
167 |
-
|
168 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
169 |
response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
|
170 |
if hashing_kv is not None:
|
171 |
-
await hashing_kv.upsert(
|
172 |
-
{args_hash: {"return": response_text, "model": model}}
|
173 |
-
)
|
174 |
return response_text
|
175 |
|
|
|
176 |
async def ollama_model_if_cache(
|
177 |
model, prompt, system_prompt=None, history_messages=[], **kwargs
|
178 |
) -> str:
|
@@ -202,6 +246,7 @@ async def ollama_model_if_cache(
|
|
202 |
|
203 |
return result
|
204 |
|
|
|
205 |
async def gpt_4o_complete(
|
206 |
prompt, system_prompt=None, history_messages=[], **kwargs
|
207 |
) -> str:
|
@@ -241,7 +286,7 @@ async def bedrock_complete(
|
|
241 |
async def hf_model_complete(
|
242 |
prompt, system_prompt=None, history_messages=[], **kwargs
|
243 |
) -> str:
|
244 |
-
model_name = kwargs[
|
245 |
return await hf_model_if_cache(
|
246 |
model_name,
|
247 |
prompt,
|
@@ -250,10 +295,11 @@ async def hf_model_complete(
|
|
250 |
**kwargs,
|
251 |
)
|
252 |
|
|
|
253 |
async def ollama_model_complete(
|
254 |
prompt, system_prompt=None, history_messages=[], **kwargs
|
255 |
) -> str:
|
256 |
-
model_name = kwargs[
|
257 |
return await ollama_model_if_cache(
|
258 |
model_name,
|
259 |
prompt,
|
@@ -262,17 +308,25 @@ async def ollama_model_complete(
|
|
262 |
**kwargs,
|
263 |
)
|
264 |
|
|
|
265 |
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
266 |
@retry(
|
267 |
stop=stop_after_attempt(3),
|
268 |
wait=wait_exponential(multiplier=1, min=4, max=10),
|
269 |
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
270 |
)
|
271 |
-
async def openai_embedding(
|
|
|
|
|
|
|
|
|
|
|
272 |
if api_key:
|
273 |
os.environ["OPENAI_API_KEY"] = api_key
|
274 |
|
275 |
-
openai_async_client =
|
|
|
|
|
276 |
response = await openai_async_client.embeddings.create(
|
277 |
model=model, input=texts, encoding_format="float"
|
278 |
)
|
@@ -286,28 +340,37 @@ async def openai_embedding(texts: list[str], model: str = "text-embedding-3-smal
|
|
286 |
# retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
|
287 |
# )
|
288 |
async def bedrock_embedding(
|
289 |
-
texts: list[str],
|
290 |
-
|
291 |
-
|
292 |
-
|
293 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
294 |
|
295 |
session = aioboto3.Session()
|
296 |
async with session.client("bedrock-runtime") as bedrock_async_client:
|
297 |
-
|
298 |
if (model_provider := model.split(".")[0]) == "amazon":
|
299 |
embed_texts = []
|
300 |
for text in texts:
|
301 |
if "v2" in model:
|
302 |
-
body = json.dumps(
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
|
|
|
|
307 |
elif "v1" in model:
|
308 |
-
body = json.dumps({
|
309 |
-
'inputText': text
|
310 |
-
})
|
311 |
else:
|
312 |
raise ValueError(f"Model {model} is not supported!")
|
313 |
|
@@ -315,29 +378,27 @@ async def bedrock_embedding(
|
|
315 |
modelId=model,
|
316 |
body=body,
|
317 |
accept="application/json",
|
318 |
-
contentType="application/json"
|
319 |
)
|
320 |
|
321 |
-
response_body = await response.get(
|
322 |
|
323 |
-
embed_texts.append(response_body[
|
324 |
elif model_provider == "cohere":
|
325 |
-
body = json.dumps(
|
326 |
-
|
327 |
-
|
328 |
-
'truncate': "NONE"
|
329 |
-
})
|
330 |
|
331 |
response = await bedrock_async_client.invoke_model(
|
332 |
model=model,
|
333 |
body=body,
|
334 |
accept="application/json",
|
335 |
-
contentType="application/json"
|
336 |
)
|
337 |
|
338 |
-
response_body = json.loads(response.get(
|
339 |
|
340 |
-
embed_texts = response_body[
|
341 |
else:
|
342 |
raise ValueError(f"Model provider '{model_provider}' is not supported!")
|
343 |
|
@@ -345,12 +406,15 @@ async def bedrock_embedding(
|
|
345 |
|
346 |
|
347 |
async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
|
348 |
-
input_ids = tokenizer(
|
|
|
|
|
349 |
with torch.no_grad():
|
350 |
outputs = embed_model(input_ids)
|
351 |
embeddings = outputs.last_hidden_state.mean(dim=1)
|
352 |
return embeddings.detach().numpy()
|
353 |
|
|
|
354 |
async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
|
355 |
embed_text = []
|
356 |
for text in texts:
|
@@ -359,11 +423,12 @@ async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
|
|
359 |
|
360 |
return embed_text
|
361 |
|
|
|
362 |
if __name__ == "__main__":
|
363 |
import asyncio
|
364 |
|
365 |
async def main():
|
366 |
-
result = await gpt_4o_mini_complete(
|
367 |
print(result)
|
368 |
|
369 |
asyncio.run(main())
|
|
|
1 |
import os
|
2 |
import copy
|
3 |
import json
|
|
|
4 |
import aioboto3
|
|
|
5 |
import numpy as np
|
6 |
import ollama
|
7 |
from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
|
|
|
11 |
wait_exponential,
|
12 |
retry_if_exception_type,
|
13 |
)
|
14 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
15 |
import torch
|
16 |
from .base import BaseKVStorage
|
17 |
from .utils import compute_args_hash, wrap_embedding_func_with_attrs
|
18 |
+
|
19 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
20 |
+
|
21 |
+
|
22 |
@retry(
|
23 |
stop=stop_after_attempt(3),
|
24 |
wait=wait_exponential(multiplier=1, min=4, max=10),
|
25 |
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
26 |
)
|
27 |
async def openai_complete_if_cache(
|
28 |
+
model,
|
29 |
+
prompt,
|
30 |
+
system_prompt=None,
|
31 |
+
history_messages=[],
|
32 |
+
base_url=None,
|
33 |
+
api_key=None,
|
34 |
+
**kwargs,
|
35 |
) -> str:
|
36 |
if api_key:
|
37 |
os.environ["OPENAI_API_KEY"] = api_key
|
38 |
|
39 |
+
openai_async_client = (
|
40 |
+
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
|
41 |
+
)
|
42 |
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
43 |
messages = []
|
44 |
if system_prompt:
|
|
|
72 |
retry=retry_if_exception_type((BedrockError)),
|
73 |
)
|
74 |
async def bedrock_complete_if_cache(
|
75 |
+
model,
|
76 |
+
prompt,
|
77 |
+
system_prompt=None,
|
78 |
+
history_messages=[],
|
79 |
+
aws_access_key_id=None,
|
80 |
+
aws_secret_access_key=None,
|
81 |
+
aws_session_token=None,
|
82 |
+
**kwargs,
|
83 |
) -> str:
|
84 |
+
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
|
85 |
+
"AWS_ACCESS_KEY_ID", aws_access_key_id
|
86 |
+
)
|
87 |
+
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
|
88 |
+
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
|
89 |
+
)
|
90 |
+
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
|
91 |
+
"AWS_SESSION_TOKEN", aws_session_token
|
92 |
+
)
|
93 |
|
94 |
# Fix message history format
|
95 |
messages = []
|
96 |
for history_message in history_messages:
|
97 |
message = copy.copy(history_message)
|
98 |
+
message["content"] = [{"text": message["content"]}]
|
99 |
messages.append(message)
|
100 |
|
101 |
# Add user prompt
|
102 |
+
messages.append({"role": "user", "content": [{"text": prompt}]})
|
103 |
|
104 |
# Initialize Converse API arguments
|
105 |
+
args = {"modelId": model, "messages": messages}
|
|
|
|
|
|
|
106 |
|
107 |
# Define system prompt
|
108 |
if system_prompt:
|
109 |
+
args["system"] = [{"text": system_prompt}]
|
110 |
|
111 |
# Map and set up inference parameters
|
112 |
inference_params_map = {
|
113 |
+
"max_tokens": "maxTokens",
|
114 |
+
"top_p": "topP",
|
115 |
+
"stop_sequences": "stopSequences",
|
116 |
}
|
117 |
+
if inference_params := list(
|
118 |
+
set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
|
119 |
+
):
|
120 |
+
args["inferenceConfig"] = {}
|
121 |
for param in inference_params:
|
122 |
+
args["inferenceConfig"][inference_params_map.get(param, param)] = (
|
123 |
+
kwargs.pop(param)
|
124 |
+
)
|
125 |
|
126 |
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
127 |
if hashing_kv is not None:
|
|
|
133 |
# Call model via Converse API
|
134 |
session = aioboto3.Session()
|
135 |
async with session.client("bedrock-runtime") as bedrock_async_client:
|
|
|
136 |
try:
|
137 |
response = await bedrock_async_client.converse(**args, **kwargs)
|
138 |
except Exception as e:
|
139 |
raise BedrockError(e)
|
140 |
|
141 |
if hashing_kv is not None:
|
142 |
+
await hashing_kv.upsert(
|
143 |
+
{
|
144 |
+
args_hash: {
|
145 |
+
"return": response["output"]["message"]["content"][0]["text"],
|
146 |
+
"model": model,
|
147 |
+
}
|
148 |
}
|
149 |
+
)
|
150 |
+
|
151 |
+
return response["output"]["message"]["content"][0]["text"]
|
152 |
|
|
|
153 |
|
154 |
async def hf_model_if_cache(
|
155 |
model, prompt, system_prompt=None, history_messages=[], **kwargs
|
156 |
) -> str:
|
157 |
model_name = model
|
158 |
+
hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
|
159 |
+
if hf_tokenizer.pad_token is None:
|
160 |
# print("use eos token")
|
161 |
hf_tokenizer.pad_token = hf_tokenizer.eos_token
|
162 |
+
hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
|
163 |
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
164 |
messages = []
|
165 |
if system_prompt:
|
|
|
172 |
if_cache_return = await hashing_kv.get_by_id(args_hash)
|
173 |
if if_cache_return is not None:
|
174 |
return if_cache_return["return"]
|
175 |
+
input_prompt = ""
|
176 |
try:
|
177 |
+
input_prompt = hf_tokenizer.apply_chat_template(
|
178 |
+
messages, tokenize=False, add_generation_prompt=True
|
179 |
+
)
|
180 |
+
except Exception:
|
181 |
try:
|
182 |
ori_message = copy.deepcopy(messages)
|
183 |
+
if messages[0]["role"] == "system":
|
184 |
+
messages[1]["content"] = (
|
185 |
+
"<system>"
|
186 |
+
+ messages[0]["content"]
|
187 |
+
+ "</system>\n"
|
188 |
+
+ messages[1]["content"]
|
189 |
+
)
|
190 |
messages = messages[1:]
|
191 |
+
input_prompt = hf_tokenizer.apply_chat_template(
|
192 |
+
messages, tokenize=False, add_generation_prompt=True
|
193 |
+
)
|
194 |
+
except Exception:
|
195 |
len_message = len(ori_message)
|
196 |
for msgid in range(len_message):
|
197 |
+
input_prompt = (
|
198 |
+
input_prompt
|
199 |
+
+ "<"
|
200 |
+
+ ori_message[msgid]["role"]
|
201 |
+
+ ">"
|
202 |
+
+ ori_message[msgid]["content"]
|
203 |
+
+ "</"
|
204 |
+
+ ori_message[msgid]["role"]
|
205 |
+
+ ">\n"
|
206 |
+
)
|
207 |
+
|
208 |
+
input_ids = hf_tokenizer(
|
209 |
+
input_prompt, return_tensors="pt", padding=True, truncation=True
|
210 |
+
).to("cuda")
|
211 |
+
output = hf_model.generate(
|
212 |
+
**input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True
|
213 |
+
)
|
214 |
response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
|
215 |
if hashing_kv is not None:
|
216 |
+
await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
|
|
|
|
|
217 |
return response_text
|
218 |
|
219 |
+
|
220 |
async def ollama_model_if_cache(
|
221 |
model, prompt, system_prompt=None, history_messages=[], **kwargs
|
222 |
) -> str:
|
|
|
246 |
|
247 |
return result
|
248 |
|
249 |
+
|
250 |
async def gpt_4o_complete(
|
251 |
prompt, system_prompt=None, history_messages=[], **kwargs
|
252 |
) -> str:
|
|
|
286 |
async def hf_model_complete(
|
287 |
prompt, system_prompt=None, history_messages=[], **kwargs
|
288 |
) -> str:
|
289 |
+
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
290 |
return await hf_model_if_cache(
|
291 |
model_name,
|
292 |
prompt,
|
|
|
295 |
**kwargs,
|
296 |
)
|
297 |
|
298 |
+
|
299 |
async def ollama_model_complete(
|
300 |
prompt, system_prompt=None, history_messages=[], **kwargs
|
301 |
) -> str:
|
302 |
+
model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
|
303 |
return await ollama_model_if_cache(
|
304 |
model_name,
|
305 |
prompt,
|
|
|
308 |
**kwargs,
|
309 |
)
|
310 |
|
311 |
+
|
312 |
@wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
|
313 |
@retry(
|
314 |
stop=stop_after_attempt(3),
|
315 |
wait=wait_exponential(multiplier=1, min=4, max=10),
|
316 |
retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
|
317 |
)
|
318 |
+
async def openai_embedding(
|
319 |
+
texts: list[str],
|
320 |
+
model: str = "text-embedding-3-small",
|
321 |
+
base_url: str = None,
|
322 |
+
api_key: str = None,
|
323 |
+
) -> np.ndarray:
|
324 |
if api_key:
|
325 |
os.environ["OPENAI_API_KEY"] = api_key
|
326 |
|
327 |
+
openai_async_client = (
|
328 |
+
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
|
329 |
+
)
|
330 |
response = await openai_async_client.embeddings.create(
|
331 |
model=model, input=texts, encoding_format="float"
|
332 |
)
|
|
|
340 |
# retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
|
341 |
# )
|
342 |
async def bedrock_embedding(
|
343 |
+
texts: list[str],
|
344 |
+
model: str = "amazon.titan-embed-text-v2:0",
|
345 |
+
aws_access_key_id=None,
|
346 |
+
aws_secret_access_key=None,
|
347 |
+
aws_session_token=None,
|
348 |
+
) -> np.ndarray:
|
349 |
+
os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
|
350 |
+
"AWS_ACCESS_KEY_ID", aws_access_key_id
|
351 |
+
)
|
352 |
+
os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
|
353 |
+
"AWS_SECRET_ACCESS_KEY", aws_secret_access_key
|
354 |
+
)
|
355 |
+
os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
|
356 |
+
"AWS_SESSION_TOKEN", aws_session_token
|
357 |
+
)
|
358 |
|
359 |
session = aioboto3.Session()
|
360 |
async with session.client("bedrock-runtime") as bedrock_async_client:
|
|
|
361 |
if (model_provider := model.split(".")[0]) == "amazon":
|
362 |
embed_texts = []
|
363 |
for text in texts:
|
364 |
if "v2" in model:
|
365 |
+
body = json.dumps(
|
366 |
+
{
|
367 |
+
"inputText": text,
|
368 |
+
# 'dimensions': embedding_dim,
|
369 |
+
"embeddingTypes": ["float"],
|
370 |
+
}
|
371 |
+
)
|
372 |
elif "v1" in model:
|
373 |
+
body = json.dumps({"inputText": text})
|
|
|
|
|
374 |
else:
|
375 |
raise ValueError(f"Model {model} is not supported!")
|
376 |
|
|
|
378 |
modelId=model,
|
379 |
body=body,
|
380 |
accept="application/json",
|
381 |
+
contentType="application/json",
|
382 |
)
|
383 |
|
384 |
+
response_body = await response.get("body").json()
|
385 |
|
386 |
+
embed_texts.append(response_body["embedding"])
|
387 |
elif model_provider == "cohere":
|
388 |
+
body = json.dumps(
|
389 |
+
{"texts": texts, "input_type": "search_document", "truncate": "NONE"}
|
390 |
+
)
|
|
|
|
|
391 |
|
392 |
response = await bedrock_async_client.invoke_model(
|
393 |
model=model,
|
394 |
body=body,
|
395 |
accept="application/json",
|
396 |
+
contentType="application/json",
|
397 |
)
|
398 |
|
399 |
+
response_body = json.loads(response.get("body").read())
|
400 |
|
401 |
+
embed_texts = response_body["embeddings"]
|
402 |
else:
|
403 |
raise ValueError(f"Model provider '{model_provider}' is not supported!")
|
404 |
|
|
|
406 |
|
407 |
|
408 |
async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
|
409 |
+
input_ids = tokenizer(
|
410 |
+
texts, return_tensors="pt", padding=True, truncation=True
|
411 |
+
).input_ids
|
412 |
with torch.no_grad():
|
413 |
outputs = embed_model(input_ids)
|
414 |
embeddings = outputs.last_hidden_state.mean(dim=1)
|
415 |
return embeddings.detach().numpy()
|
416 |
|
417 |
+
|
418 |
async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
|
419 |
embed_text = []
|
420 |
for text in texts:
|
|
|
423 |
|
424 |
return embed_text
|
425 |
|
426 |
+
|
427 |
if __name__ == "__main__":
|
428 |
import asyncio
|
429 |
|
430 |
async def main():
|
431 |
+
result = await gpt_4o_mini_complete("How are you?")
|
432 |
print(result)
|
433 |
|
434 |
asyncio.run(main())
|
lightrag/operate.py
CHANGED
@@ -25,6 +25,7 @@ from .base import (
|
|
25 |
)
|
26 |
from .prompt import GRAPH_FIELD_SEP, PROMPTS
|
27 |
|
|
|
28 |
def chunking_by_token_size(
|
29 |
content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
|
30 |
):
|
@@ -45,6 +46,7 @@ def chunking_by_token_size(
|
|
45 |
)
|
46 |
return results
|
47 |
|
|
|
48 |
async def _handle_entity_relation_summary(
|
49 |
entity_or_relation_name: str,
|
50 |
description: str,
|
@@ -229,9 +231,10 @@ async def _merge_edges_then_upsert(
|
|
229 |
description=description,
|
230 |
keywords=keywords,
|
231 |
)
|
232 |
-
|
233 |
return edge_data
|
234 |
|
|
|
235 |
async def extract_entities(
|
236 |
chunks: dict[str, TextChunkSchema],
|
237 |
knwoledge_graph_inst: BaseGraphStorage,
|
@@ -352,7 +355,9 @@ async def extract_entities(
|
|
352 |
logger.warning("Didn't extract any entities, maybe your LLM is not working")
|
353 |
return None
|
354 |
if not len(all_relationships_data):
|
355 |
-
logger.warning(
|
|
|
|
|
356 |
return None
|
357 |
|
358 |
if entity_vdb is not None:
|
@@ -370,7 +375,10 @@ async def extract_entities(
|
|
370 |
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
|
371 |
"src_id": dp["src_id"],
|
372 |
"tgt_id": dp["tgt_id"],
|
373 |
-
"content": dp["keywords"]
|
|
|
|
|
|
|
374 |
}
|
375 |
for dp in all_relationships_data
|
376 |
}
|
@@ -378,6 +386,7 @@ async def extract_entities(
|
|
378 |
|
379 |
return knwoledge_graph_inst
|
380 |
|
|
|
381 |
async def local_query(
|
382 |
query,
|
383 |
knowledge_graph_inst: BaseGraphStorage,
|
@@ -393,19 +402,24 @@ async def local_query(
|
|
393 |
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
394 |
kw_prompt = kw_prompt_temp.format(query=query)
|
395 |
result = await use_model_func(kw_prompt)
|
396 |
-
|
397 |
try:
|
398 |
keywords_data = json.loads(result)
|
399 |
keywords = keywords_data.get("low_level_keywords", [])
|
400 |
-
keywords =
|
401 |
-
except json.JSONDecodeError
|
402 |
try:
|
403 |
-
result =
|
404 |
-
|
|
|
|
|
|
|
|
|
|
|
405 |
|
406 |
keywords_data = json.loads(result)
|
407 |
keywords = keywords_data.get("low_level_keywords", [])
|
408 |
-
keywords =
|
409 |
# Handle parsing error
|
410 |
except json.JSONDecodeError as e:
|
411 |
print(f"JSON parsing error: {e}")
|
@@ -430,11 +444,20 @@ async def local_query(
|
|
430 |
query,
|
431 |
system_prompt=sys_prompt,
|
432 |
)
|
433 |
-
if len(response)>len(sys_prompt):
|
434 |
-
response =
|
435 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
436 |
return response
|
437 |
|
|
|
438 |
async def _build_local_query_context(
|
439 |
query,
|
440 |
knowledge_graph_inst: BaseGraphStorage,
|
@@ -516,6 +539,7 @@ async def _build_local_query_context(
|
|
516 |
```
|
517 |
"""
|
518 |
|
|
|
519 |
async def _find_most_related_text_unit_from_entities(
|
520 |
node_datas: list[dict],
|
521 |
query_param: QueryParam,
|
@@ -576,6 +600,7 @@ async def _find_most_related_text_unit_from_entities(
|
|
576 |
all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
|
577 |
return all_text_units
|
578 |
|
|
|
579 |
async def _find_most_related_edges_from_entities(
|
580 |
node_datas: list[dict],
|
581 |
query_param: QueryParam,
|
@@ -609,6 +634,7 @@ async def _find_most_related_edges_from_entities(
|
|
609 |
)
|
610 |
return all_edges_data
|
611 |
|
|
|
612 |
async def global_query(
|
613 |
query,
|
614 |
knowledge_graph_inst: BaseGraphStorage,
|
@@ -624,20 +650,25 @@ async def global_query(
|
|
624 |
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
625 |
kw_prompt = kw_prompt_temp.format(query=query)
|
626 |
result = await use_model_func(kw_prompt)
|
627 |
-
|
628 |
try:
|
629 |
keywords_data = json.loads(result)
|
630 |
keywords = keywords_data.get("high_level_keywords", [])
|
631 |
-
keywords =
|
632 |
-
except json.JSONDecodeError
|
633 |
try:
|
634 |
-
result =
|
635 |
-
|
|
|
|
|
|
|
|
|
|
|
636 |
|
637 |
keywords_data = json.loads(result)
|
638 |
keywords = keywords_data.get("high_level_keywords", [])
|
639 |
-
keywords =
|
640 |
-
|
641 |
except json.JSONDecodeError as e:
|
642 |
# Handle parsing error
|
643 |
print(f"JSON parsing error: {e}")
|
@@ -651,12 +682,12 @@ async def global_query(
|
|
651 |
text_chunks_db,
|
652 |
query_param,
|
653 |
)
|
654 |
-
|
655 |
if query_param.only_need_context:
|
656 |
return context
|
657 |
if context is None:
|
658 |
return PROMPTS["fail_response"]
|
659 |
-
|
660 |
sys_prompt_temp = PROMPTS["rag_response"]
|
661 |
sys_prompt = sys_prompt_temp.format(
|
662 |
context_data=context, response_type=query_param.response_type
|
@@ -665,11 +696,20 @@ async def global_query(
|
|
665 |
query,
|
666 |
system_prompt=sys_prompt,
|
667 |
)
|
668 |
-
if len(response)>len(sys_prompt):
|
669 |
-
response =
|
670 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
671 |
return response
|
672 |
|
|
|
673 |
async def _build_global_query_context(
|
674 |
keywords,
|
675 |
knowledge_graph_inst: BaseGraphStorage,
|
@@ -679,14 +719,14 @@ async def _build_global_query_context(
|
|
679 |
query_param: QueryParam,
|
680 |
):
|
681 |
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
|
682 |
-
|
683 |
if not len(results):
|
684 |
return None
|
685 |
-
|
686 |
edge_datas = await asyncio.gather(
|
687 |
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
|
688 |
)
|
689 |
-
|
690 |
if not all([n is not None for n in edge_datas]):
|
691 |
logger.warning("Some edges are missing, maybe the storage is damaged")
|
692 |
edge_degree = await asyncio.gather(
|
@@ -765,6 +805,7 @@ async def _build_global_query_context(
|
|
765 |
```
|
766 |
"""
|
767 |
|
|
|
768 |
async def _find_most_related_entities_from_relationships(
|
769 |
edge_datas: list[dict],
|
770 |
query_param: QueryParam,
|
@@ -774,7 +815,7 @@ async def _find_most_related_entities_from_relationships(
|
|
774 |
for e in edge_datas:
|
775 |
entity_names.add(e["src_id"])
|
776 |
entity_names.add(e["tgt_id"])
|
777 |
-
|
778 |
node_datas = await asyncio.gather(
|
779 |
*[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
|
780 |
)
|
@@ -795,13 +836,13 @@ async def _find_most_related_entities_from_relationships(
|
|
795 |
|
796 |
return node_datas
|
797 |
|
|
|
798 |
async def _find_related_text_unit_from_relationships(
|
799 |
edge_datas: list[dict],
|
800 |
query_param: QueryParam,
|
801 |
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
802 |
knowledge_graph_inst: BaseGraphStorage,
|
803 |
):
|
804 |
-
|
805 |
text_units = [
|
806 |
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
|
807 |
for dp in edge_datas
|
@@ -816,15 +857,13 @@ async def _find_related_text_unit_from_relationships(
|
|
816 |
"data": await text_chunks_db.get_by_id(c_id),
|
817 |
"order": index,
|
818 |
}
|
819 |
-
|
820 |
if any([v is None for v in all_text_units_lookup.values()]):
|
821 |
logger.warning("Text chunks are missing, maybe the storage is damaged")
|
822 |
all_text_units = [
|
823 |
{"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
|
824 |
]
|
825 |
-
all_text_units = sorted(
|
826 |
-
all_text_units, key=lambda x: x["order"]
|
827 |
-
)
|
828 |
all_text_units = truncate_list_by_token_size(
|
829 |
all_text_units,
|
830 |
key=lambda x: x["data"]["content"],
|
@@ -834,6 +873,7 @@ async def _find_related_text_unit_from_relationships(
|
|
834 |
|
835 |
return all_text_units
|
836 |
|
|
|
837 |
async def hybrid_query(
|
838 |
query,
|
839 |
knowledge_graph_inst: BaseGraphStorage,
|
@@ -849,24 +889,29 @@ async def hybrid_query(
|
|
849 |
|
850 |
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
851 |
kw_prompt = kw_prompt_temp.format(query=query)
|
852 |
-
|
853 |
result = await use_model_func(kw_prompt)
|
854 |
try:
|
855 |
keywords_data = json.loads(result)
|
856 |
hl_keywords = keywords_data.get("high_level_keywords", [])
|
857 |
ll_keywords = keywords_data.get("low_level_keywords", [])
|
858 |
-
hl_keywords =
|
859 |
-
ll_keywords =
|
860 |
-
except json.JSONDecodeError
|
861 |
try:
|
862 |
-
result =
|
863 |
-
|
|
|
|
|
|
|
|
|
|
|
864 |
|
865 |
keywords_data = json.loads(result)
|
866 |
hl_keywords = keywords_data.get("high_level_keywords", [])
|
867 |
ll_keywords = keywords_data.get("low_level_keywords", [])
|
868 |
-
hl_keywords =
|
869 |
-
ll_keywords =
|
870 |
# Handle parsing error
|
871 |
except json.JSONDecodeError as e:
|
872 |
print(f"JSON parsing error: {e}")
|
@@ -897,7 +942,7 @@ async def hybrid_query(
|
|
897 |
return context
|
898 |
if context is None:
|
899 |
return PROMPTS["fail_response"]
|
900 |
-
|
901 |
sys_prompt_temp = PROMPTS["rag_response"]
|
902 |
sys_prompt = sys_prompt_temp.format(
|
903 |
context_data=context, response_type=query_param.response_type
|
@@ -906,53 +951,78 @@ async def hybrid_query(
|
|
906 |
query,
|
907 |
system_prompt=sys_prompt,
|
908 |
)
|
909 |
-
if len(response)>len(sys_prompt):
|
910 |
-
response =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
911 |
return response
|
912 |
|
|
|
913 |
def combine_contexts(high_level_context, low_level_context):
|
914 |
# Function to extract entities, relationships, and sources from context strings
|
915 |
|
916 |
def extract_sections(context):
|
917 |
-
entities_match = re.search(
|
918 |
-
|
919 |
-
|
920 |
-
|
921 |
-
|
922 |
-
|
923 |
-
|
924 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
925 |
return entities, relationships, sources
|
926 |
-
|
927 |
# Extract sections from both contexts
|
928 |
|
929 |
-
if high_level_context
|
930 |
-
warnings.warn(
|
931 |
-
|
|
|
|
|
932 |
else:
|
933 |
hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
|
934 |
|
935 |
-
|
936 |
-
|
937 |
-
|
938 |
-
|
|
|
939 |
else:
|
940 |
ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
|
941 |
|
942 |
-
|
943 |
-
|
944 |
# Combine and deduplicate the entities
|
945 |
-
combined_entities_set = set(
|
946 |
-
|
947 |
-
|
|
|
|
|
948 |
# Combine and deduplicate the relationships
|
949 |
-
combined_relationships_set = set(
|
950 |
-
|
951 |
-
|
|
|
|
|
|
|
|
|
|
|
952 |
# Combine and deduplicate the sources
|
953 |
-
combined_sources_set = set(
|
954 |
-
|
955 |
-
|
|
|
|
|
956 |
# Format the combined context
|
957 |
return f"""
|
958 |
-----Entities-----
|
@@ -964,6 +1034,7 @@ def combine_contexts(high_level_context, low_level_context):
|
|
964 |
{combined_sources}
|
965 |
"""
|
966 |
|
|
|
967 |
async def naive_query(
|
968 |
query,
|
969 |
chunks_vdb: BaseVectorStorage,
|
@@ -996,8 +1067,16 @@ async def naive_query(
|
|
996 |
system_prompt=sys_prompt,
|
997 |
)
|
998 |
|
999 |
-
if len(response)>len(sys_prompt):
|
1000 |
-
response =
|
1001 |
-
|
1002 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1003 |
|
|
|
|
25 |
)
|
26 |
from .prompt import GRAPH_FIELD_SEP, PROMPTS
|
27 |
|
28 |
+
|
29 |
def chunking_by_token_size(
|
30 |
content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
|
31 |
):
|
|
|
46 |
)
|
47 |
return results
|
48 |
|
49 |
+
|
50 |
async def _handle_entity_relation_summary(
|
51 |
entity_or_relation_name: str,
|
52 |
description: str,
|
|
|
231 |
description=description,
|
232 |
keywords=keywords,
|
233 |
)
|
234 |
+
|
235 |
return edge_data
|
236 |
|
237 |
+
|
238 |
async def extract_entities(
|
239 |
chunks: dict[str, TextChunkSchema],
|
240 |
knwoledge_graph_inst: BaseGraphStorage,
|
|
|
355 |
logger.warning("Didn't extract any entities, maybe your LLM is not working")
|
356 |
return None
|
357 |
if not len(all_relationships_data):
|
358 |
+
logger.warning(
|
359 |
+
"Didn't extract any relationships, maybe your LLM is not working"
|
360 |
+
)
|
361 |
return None
|
362 |
|
363 |
if entity_vdb is not None:
|
|
|
375 |
compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
|
376 |
"src_id": dp["src_id"],
|
377 |
"tgt_id": dp["tgt_id"],
|
378 |
+
"content": dp["keywords"]
|
379 |
+
+ dp["src_id"]
|
380 |
+
+ dp["tgt_id"]
|
381 |
+
+ dp["description"],
|
382 |
}
|
383 |
for dp in all_relationships_data
|
384 |
}
|
|
|
386 |
|
387 |
return knwoledge_graph_inst
|
388 |
|
389 |
+
|
390 |
async def local_query(
|
391 |
query,
|
392 |
knowledge_graph_inst: BaseGraphStorage,
|
|
|
402 |
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
403 |
kw_prompt = kw_prompt_temp.format(query=query)
|
404 |
result = await use_model_func(kw_prompt)
|
405 |
+
|
406 |
try:
|
407 |
keywords_data = json.loads(result)
|
408 |
keywords = keywords_data.get("low_level_keywords", [])
|
409 |
+
keywords = ", ".join(keywords)
|
410 |
+
except json.JSONDecodeError:
|
411 |
try:
|
412 |
+
result = (
|
413 |
+
result.replace(kw_prompt[:-1], "")
|
414 |
+
.replace("user", "")
|
415 |
+
.replace("model", "")
|
416 |
+
.strip()
|
417 |
+
)
|
418 |
+
result = "{" + result.split("{")[1].split("}")[0] + "}"
|
419 |
|
420 |
keywords_data = json.loads(result)
|
421 |
keywords = keywords_data.get("low_level_keywords", [])
|
422 |
+
keywords = ", ".join(keywords)
|
423 |
# Handle parsing error
|
424 |
except json.JSONDecodeError as e:
|
425 |
print(f"JSON parsing error: {e}")
|
|
|
444 |
query,
|
445 |
system_prompt=sys_prompt,
|
446 |
)
|
447 |
+
if len(response) > len(sys_prompt):
|
448 |
+
response = (
|
449 |
+
response.replace(sys_prompt, "")
|
450 |
+
.replace("user", "")
|
451 |
+
.replace("model", "")
|
452 |
+
.replace(query, "")
|
453 |
+
.replace("<system>", "")
|
454 |
+
.replace("</system>", "")
|
455 |
+
.strip()
|
456 |
+
)
|
457 |
+
|
458 |
return response
|
459 |
|
460 |
+
|
461 |
async def _build_local_query_context(
|
462 |
query,
|
463 |
knowledge_graph_inst: BaseGraphStorage,
|
|
|
539 |
```
|
540 |
"""
|
541 |
|
542 |
+
|
543 |
async def _find_most_related_text_unit_from_entities(
|
544 |
node_datas: list[dict],
|
545 |
query_param: QueryParam,
|
|
|
600 |
all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
|
601 |
return all_text_units
|
602 |
|
603 |
+
|
604 |
async def _find_most_related_edges_from_entities(
|
605 |
node_datas: list[dict],
|
606 |
query_param: QueryParam,
|
|
|
634 |
)
|
635 |
return all_edges_data
|
636 |
|
637 |
+
|
638 |
async def global_query(
|
639 |
query,
|
640 |
knowledge_graph_inst: BaseGraphStorage,
|
|
|
650 |
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
651 |
kw_prompt = kw_prompt_temp.format(query=query)
|
652 |
result = await use_model_func(kw_prompt)
|
653 |
+
|
654 |
try:
|
655 |
keywords_data = json.loads(result)
|
656 |
keywords = keywords_data.get("high_level_keywords", [])
|
657 |
+
keywords = ", ".join(keywords)
|
658 |
+
except json.JSONDecodeError:
|
659 |
try:
|
660 |
+
result = (
|
661 |
+
result.replace(kw_prompt[:-1], "")
|
662 |
+
.replace("user", "")
|
663 |
+
.replace("model", "")
|
664 |
+
.strip()
|
665 |
+
)
|
666 |
+
result = "{" + result.split("{")[1].split("}")[0] + "}"
|
667 |
|
668 |
keywords_data = json.loads(result)
|
669 |
keywords = keywords_data.get("high_level_keywords", [])
|
670 |
+
keywords = ", ".join(keywords)
|
671 |
+
|
672 |
except json.JSONDecodeError as e:
|
673 |
# Handle parsing error
|
674 |
print(f"JSON parsing error: {e}")
|
|
|
682 |
text_chunks_db,
|
683 |
query_param,
|
684 |
)
|
685 |
+
|
686 |
if query_param.only_need_context:
|
687 |
return context
|
688 |
if context is None:
|
689 |
return PROMPTS["fail_response"]
|
690 |
+
|
691 |
sys_prompt_temp = PROMPTS["rag_response"]
|
692 |
sys_prompt = sys_prompt_temp.format(
|
693 |
context_data=context, response_type=query_param.response_type
|
|
|
696 |
query,
|
697 |
system_prompt=sys_prompt,
|
698 |
)
|
699 |
+
if len(response) > len(sys_prompt):
|
700 |
+
response = (
|
701 |
+
response.replace(sys_prompt, "")
|
702 |
+
.replace("user", "")
|
703 |
+
.replace("model", "")
|
704 |
+
.replace(query, "")
|
705 |
+
.replace("<system>", "")
|
706 |
+
.replace("</system>", "")
|
707 |
+
.strip()
|
708 |
+
)
|
709 |
+
|
710 |
return response
|
711 |
|
712 |
+
|
713 |
async def _build_global_query_context(
|
714 |
keywords,
|
715 |
knowledge_graph_inst: BaseGraphStorage,
|
|
|
719 |
query_param: QueryParam,
|
720 |
):
|
721 |
results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
|
722 |
+
|
723 |
if not len(results):
|
724 |
return None
|
725 |
+
|
726 |
edge_datas = await asyncio.gather(
|
727 |
*[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
|
728 |
)
|
729 |
+
|
730 |
if not all([n is not None for n in edge_datas]):
|
731 |
logger.warning("Some edges are missing, maybe the storage is damaged")
|
732 |
edge_degree = await asyncio.gather(
|
|
|
805 |
```
|
806 |
"""
|
807 |
|
808 |
+
|
809 |
async def _find_most_related_entities_from_relationships(
|
810 |
edge_datas: list[dict],
|
811 |
query_param: QueryParam,
|
|
|
815 |
for e in edge_datas:
|
816 |
entity_names.add(e["src_id"])
|
817 |
entity_names.add(e["tgt_id"])
|
818 |
+
|
819 |
node_datas = await asyncio.gather(
|
820 |
*[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
|
821 |
)
|
|
|
836 |
|
837 |
return node_datas
|
838 |
|
839 |
+
|
840 |
async def _find_related_text_unit_from_relationships(
|
841 |
edge_datas: list[dict],
|
842 |
query_param: QueryParam,
|
843 |
text_chunks_db: BaseKVStorage[TextChunkSchema],
|
844 |
knowledge_graph_inst: BaseGraphStorage,
|
845 |
):
|
|
|
846 |
text_units = [
|
847 |
split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
|
848 |
for dp in edge_datas
|
|
|
857 |
"data": await text_chunks_db.get_by_id(c_id),
|
858 |
"order": index,
|
859 |
}
|
860 |
+
|
861 |
if any([v is None for v in all_text_units_lookup.values()]):
|
862 |
logger.warning("Text chunks are missing, maybe the storage is damaged")
|
863 |
all_text_units = [
|
864 |
{"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
|
865 |
]
|
866 |
+
all_text_units = sorted(all_text_units, key=lambda x: x["order"])
|
|
|
|
|
867 |
all_text_units = truncate_list_by_token_size(
|
868 |
all_text_units,
|
869 |
key=lambda x: x["data"]["content"],
|
|
|
873 |
|
874 |
return all_text_units
|
875 |
|
876 |
+
|
877 |
async def hybrid_query(
|
878 |
query,
|
879 |
knowledge_graph_inst: BaseGraphStorage,
|
|
|
889 |
|
890 |
kw_prompt_temp = PROMPTS["keywords_extraction"]
|
891 |
kw_prompt = kw_prompt_temp.format(query=query)
|
892 |
+
|
893 |
result = await use_model_func(kw_prompt)
|
894 |
try:
|
895 |
keywords_data = json.loads(result)
|
896 |
hl_keywords = keywords_data.get("high_level_keywords", [])
|
897 |
ll_keywords = keywords_data.get("low_level_keywords", [])
|
898 |
+
hl_keywords = ", ".join(hl_keywords)
|
899 |
+
ll_keywords = ", ".join(ll_keywords)
|
900 |
+
except json.JSONDecodeError:
|
901 |
try:
|
902 |
+
result = (
|
903 |
+
result.replace(kw_prompt[:-1], "")
|
904 |
+
.replace("user", "")
|
905 |
+
.replace("model", "")
|
906 |
+
.strip()
|
907 |
+
)
|
908 |
+
result = "{" + result.split("{")[1].split("}")[0] + "}"
|
909 |
|
910 |
keywords_data = json.loads(result)
|
911 |
hl_keywords = keywords_data.get("high_level_keywords", [])
|
912 |
ll_keywords = keywords_data.get("low_level_keywords", [])
|
913 |
+
hl_keywords = ", ".join(hl_keywords)
|
914 |
+
ll_keywords = ", ".join(ll_keywords)
|
915 |
# Handle parsing error
|
916 |
except json.JSONDecodeError as e:
|
917 |
print(f"JSON parsing error: {e}")
|
|
|
942 |
return context
|
943 |
if context is None:
|
944 |
return PROMPTS["fail_response"]
|
945 |
+
|
946 |
sys_prompt_temp = PROMPTS["rag_response"]
|
947 |
sys_prompt = sys_prompt_temp.format(
|
948 |
context_data=context, response_type=query_param.response_type
|
|
|
951 |
query,
|
952 |
system_prompt=sys_prompt,
|
953 |
)
|
954 |
+
if len(response) > len(sys_prompt):
|
955 |
+
response = (
|
956 |
+
response.replace(sys_prompt, "")
|
957 |
+
.replace("user", "")
|
958 |
+
.replace("model", "")
|
959 |
+
.replace(query, "")
|
960 |
+
.replace("<system>", "")
|
961 |
+
.replace("</system>", "")
|
962 |
+
.strip()
|
963 |
+
)
|
964 |
return response
|
965 |
|
966 |
+
|
967 |
def combine_contexts(high_level_context, low_level_context):
|
968 |
# Function to extract entities, relationships, and sources from context strings
|
969 |
|
970 |
def extract_sections(context):
|
971 |
+
entities_match = re.search(
|
972 |
+
r"-----Entities-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
|
973 |
+
)
|
974 |
+
relationships_match = re.search(
|
975 |
+
r"-----Relationships-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
|
976 |
+
)
|
977 |
+
sources_match = re.search(
|
978 |
+
r"-----Sources-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
|
979 |
+
)
|
980 |
+
|
981 |
+
entities = entities_match.group(1) if entities_match else ""
|
982 |
+
relationships = relationships_match.group(1) if relationships_match else ""
|
983 |
+
sources = sources_match.group(1) if sources_match else ""
|
984 |
+
|
985 |
return entities, relationships, sources
|
986 |
+
|
987 |
# Extract sections from both contexts
|
988 |
|
989 |
+
if high_level_context is None:
|
990 |
+
warnings.warn(
|
991 |
+
"High Level context is None. Return empty High entity/relationship/source"
|
992 |
+
)
|
993 |
+
hl_entities, hl_relationships, hl_sources = "", "", ""
|
994 |
else:
|
995 |
hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
|
996 |
|
997 |
+
if low_level_context is None:
|
998 |
+
warnings.warn(
|
999 |
+
"Low Level context is None. Return empty Low entity/relationship/source"
|
1000 |
+
)
|
1001 |
+
ll_entities, ll_relationships, ll_sources = "", "", ""
|
1002 |
else:
|
1003 |
ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
|
1004 |
|
|
|
|
|
1005 |
# Combine and deduplicate the entities
|
1006 |
+
combined_entities_set = set(
|
1007 |
+
filter(None, hl_entities.strip().split("\n") + ll_entities.strip().split("\n"))
|
1008 |
+
)
|
1009 |
+
combined_entities = "\n".join(combined_entities_set)
|
1010 |
+
|
1011 |
# Combine and deduplicate the relationships
|
1012 |
+
combined_relationships_set = set(
|
1013 |
+
filter(
|
1014 |
+
None,
|
1015 |
+
hl_relationships.strip().split("\n") + ll_relationships.strip().split("\n"),
|
1016 |
+
)
|
1017 |
+
)
|
1018 |
+
combined_relationships = "\n".join(combined_relationships_set)
|
1019 |
+
|
1020 |
# Combine and deduplicate the sources
|
1021 |
+
combined_sources_set = set(
|
1022 |
+
filter(None, hl_sources.strip().split("\n") + ll_sources.strip().split("\n"))
|
1023 |
+
)
|
1024 |
+
combined_sources = "\n".join(combined_sources_set)
|
1025 |
+
|
1026 |
# Format the combined context
|
1027 |
return f"""
|
1028 |
-----Entities-----
|
|
|
1034 |
{combined_sources}
|
1035 |
"""
|
1036 |
|
1037 |
+
|
1038 |
async def naive_query(
|
1039 |
query,
|
1040 |
chunks_vdb: BaseVectorStorage,
|
|
|
1067 |
system_prompt=sys_prompt,
|
1068 |
)
|
1069 |
|
1070 |
+
if len(response) > len(sys_prompt):
|
1071 |
+
response = (
|
1072 |
+
response[len(sys_prompt) :]
|
1073 |
+
.replace(sys_prompt, "")
|
1074 |
+
.replace("user", "")
|
1075 |
+
.replace("model", "")
|
1076 |
+
.replace(query, "")
|
1077 |
+
.replace("<system>", "")
|
1078 |
+
.replace("</system>", "")
|
1079 |
+
.strip()
|
1080 |
+
)
|
1081 |
|
1082 |
+
return response
|
lightrag/prompt.py
CHANGED
@@ -9,9 +9,7 @@ PROMPTS["process_tickers"] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "
|
|
9 |
|
10 |
PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"]
|
11 |
|
12 |
-
PROMPTS[
|
13 |
-
"entity_extraction"
|
14 |
-
] = """-Goal-
|
15 |
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
|
16 |
|
17 |
-Steps-
|
@@ -32,7 +30,7 @@ Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tupl
|
|
32 |
|
33 |
3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
|
34 |
Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_level_keywords>)
|
35 |
-
|
36 |
4. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
|
37 |
|
38 |
5. When finished, output {completion_delimiter}
|
@@ -146,9 +144,7 @@ PROMPTS[
|
|
146 |
|
147 |
PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."
|
148 |
|
149 |
-
PROMPTS[
|
150 |
-
"rag_response"
|
151 |
-
] = """---Role---
|
152 |
|
153 |
You are a helpful assistant responding to questions about data in the tables provided.
|
154 |
|
@@ -241,9 +237,7 @@ Output:
|
|
241 |
|
242 |
"""
|
243 |
|
244 |
-
PROMPTS[
|
245 |
-
"naive_rag_response"
|
246 |
-
] = """You're a helpful assistant
|
247 |
Below are the knowledge you know:
|
248 |
{content_data}
|
249 |
---
|
|
|
9 |
|
10 |
PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"]
|
11 |
|
12 |
+
PROMPTS["entity_extraction"] = """-Goal-
|
|
|
|
|
13 |
Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
|
14 |
|
15 |
-Steps-
|
|
|
30 |
|
31 |
3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
|
32 |
Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_level_keywords>)
|
33 |
+
|
34 |
4. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
|
35 |
|
36 |
5. When finished, output {completion_delimiter}
|
|
|
144 |
|
145 |
PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."
|
146 |
|
147 |
+
PROMPTS["rag_response"] = """---Role---
|
|
|
|
|
148 |
|
149 |
You are a helpful assistant responding to questions about data in the tables provided.
|
150 |
|
|
|
237 |
|
238 |
"""
|
239 |
|
240 |
+
PROMPTS["naive_rag_response"] = """You're a helpful assistant
|
|
|
|
|
241 |
Below are the knowledge you know:
|
242 |
{content_data}
|
243 |
---
|
lightrag/storage.py
CHANGED
@@ -1,16 +1,11 @@
|
|
1 |
import asyncio
|
2 |
import html
|
3 |
-
import json
|
4 |
import os
|
5 |
-
from
|
6 |
-
from dataclasses import dataclass, field
|
7 |
from typing import Any, Union, cast
|
8 |
-
import pickle
|
9 |
-
import hnswlib
|
10 |
import networkx as nx
|
11 |
import numpy as np
|
12 |
from nano_vectordb import NanoVectorDB
|
13 |
-
import xxhash
|
14 |
|
15 |
from .utils import load_json, logger, write_json
|
16 |
from .base import (
|
@@ -19,6 +14,7 @@ from .base import (
|
|
19 |
BaseVectorStorage,
|
20 |
)
|
21 |
|
|
|
22 |
@dataclass
|
23 |
class JsonKVStorage(BaseKVStorage):
|
24 |
def __post_init__(self):
|
@@ -59,12 +55,12 @@ class JsonKVStorage(BaseKVStorage):
|
|
59 |
async def drop(self):
|
60 |
self._data = {}
|
61 |
|
|
|
62 |
@dataclass
|
63 |
class NanoVectorDBStorage(BaseVectorStorage):
|
64 |
cosine_better_than_threshold: float = 0.2
|
65 |
|
66 |
def __post_init__(self):
|
67 |
-
|
68 |
self._client_file_name = os.path.join(
|
69 |
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
70 |
)
|
@@ -118,6 +114,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
|
118 |
async def index_done_callback(self):
|
119 |
self._client.save()
|
120 |
|
|
|
121 |
@dataclass
|
122 |
class NetworkXStorage(BaseGraphStorage):
|
123 |
@staticmethod
|
@@ -142,7 +139,9 @@ class NetworkXStorage(BaseGraphStorage):
|
|
142 |
|
143 |
graph = graph.copy()
|
144 |
graph = cast(nx.Graph, largest_connected_component(graph))
|
145 |
-
node_mapping = {
|
|
|
|
|
146 |
graph = nx.relabel_nodes(graph, node_mapping)
|
147 |
return NetworkXStorage._stabilize_graph(graph)
|
148 |
|
|
|
1 |
import asyncio
|
2 |
import html
|
|
|
3 |
import os
|
4 |
+
from dataclasses import dataclass
|
|
|
5 |
from typing import Any, Union, cast
|
|
|
|
|
6 |
import networkx as nx
|
7 |
import numpy as np
|
8 |
from nano_vectordb import NanoVectorDB
|
|
|
9 |
|
10 |
from .utils import load_json, logger, write_json
|
11 |
from .base import (
|
|
|
14 |
BaseVectorStorage,
|
15 |
)
|
16 |
|
17 |
+
|
18 |
@dataclass
|
19 |
class JsonKVStorage(BaseKVStorage):
|
20 |
def __post_init__(self):
|
|
|
55 |
async def drop(self):
|
56 |
self._data = {}
|
57 |
|
58 |
+
|
59 |
@dataclass
|
60 |
class NanoVectorDBStorage(BaseVectorStorage):
|
61 |
cosine_better_than_threshold: float = 0.2
|
62 |
|
63 |
def __post_init__(self):
|
|
|
64 |
self._client_file_name = os.path.join(
|
65 |
self.global_config["working_dir"], f"vdb_{self.namespace}.json"
|
66 |
)
|
|
|
114 |
async def index_done_callback(self):
|
115 |
self._client.save()
|
116 |
|
117 |
+
|
118 |
@dataclass
|
119 |
class NetworkXStorage(BaseGraphStorage):
|
120 |
@staticmethod
|
|
|
139 |
|
140 |
graph = graph.copy()
|
141 |
graph = cast(nx.Graph, largest_connected_component(graph))
|
142 |
+
node_mapping = {
|
143 |
+
node: html.unescape(node.upper().strip()) for node in graph.nodes()
|
144 |
+
} # type: ignore
|
145 |
graph = nx.relabel_nodes(graph, node_mapping)
|
146 |
return NetworkXStorage._stabilize_graph(graph)
|
147 |
|
lightrag/utils.py
CHANGED
@@ -16,18 +16,22 @@ ENCODER = None
|
|
16 |
|
17 |
logger = logging.getLogger("lightrag")
|
18 |
|
|
|
19 |
def set_logger(log_file: str):
|
20 |
logger.setLevel(logging.DEBUG)
|
21 |
|
22 |
file_handler = logging.FileHandler(log_file)
|
23 |
file_handler.setLevel(logging.DEBUG)
|
24 |
|
25 |
-
formatter = logging.Formatter(
|
|
|
|
|
26 |
file_handler.setFormatter(formatter)
|
27 |
|
28 |
if not logger.handlers:
|
29 |
logger.addHandler(file_handler)
|
30 |
|
|
|
31 |
@dataclass
|
32 |
class EmbeddingFunc:
|
33 |
embedding_dim: int
|
@@ -36,7 +40,8 @@ class EmbeddingFunc:
|
|
36 |
|
37 |
async def __call__(self, *args, **kwargs) -> np.ndarray:
|
38 |
return await self.func(*args, **kwargs)
|
39 |
-
|
|
|
40 |
def locate_json_string_body_from_string(content: str) -> Union[str, None]:
|
41 |
"""Locate the JSON string body from a string"""
|
42 |
maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
|
@@ -45,6 +50,7 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
|
|
45 |
else:
|
46 |
return None
|
47 |
|
|
|
48 |
def convert_response_to_json(response: str) -> dict:
|
49 |
json_str = locate_json_string_body_from_string(response)
|
50 |
assert json_str is not None, f"Unable to parse JSON from response: {response}"
|
@@ -55,12 +61,15 @@ def convert_response_to_json(response: str) -> dict:
|
|
55 |
logger.error(f"Failed to parse JSON: {json_str}")
|
56 |
raise e from None
|
57 |
|
|
|
58 |
def compute_args_hash(*args):
|
59 |
return md5(str(args).encode()).hexdigest()
|
60 |
|
|
|
61 |
def compute_mdhash_id(content, prefix: str = ""):
|
62 |
return prefix + md5(content.encode()).hexdigest()
|
63 |
|
|
|
64 |
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
|
65 |
"""Add restriction of maximum async calling times for a async func"""
|
66 |
|
@@ -82,6 +91,7 @@ def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
|
|
82 |
|
83 |
return final_decro
|
84 |
|
|
|
85 |
def wrap_embedding_func_with_attrs(**kwargs):
|
86 |
"""Wrap a function with attributes"""
|
87 |
|
@@ -91,16 +101,19 @@ def wrap_embedding_func_with_attrs(**kwargs):
|
|
91 |
|
92 |
return final_decro
|
93 |
|
|
|
94 |
def load_json(file_name):
|
95 |
if not os.path.exists(file_name):
|
96 |
return None
|
97 |
with open(file_name, encoding="utf-8") as f:
|
98 |
return json.load(f)
|
99 |
|
|
|
100 |
def write_json(json_obj, file_name):
|
101 |
with open(file_name, "w", encoding="utf-8") as f:
|
102 |
json.dump(json_obj, f, indent=2, ensure_ascii=False)
|
103 |
|
|
|
104 |
def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
|
105 |
global ENCODER
|
106 |
if ENCODER is None:
|
@@ -116,12 +129,14 @@ def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
|
|
116 |
content = ENCODER.decode(tokens)
|
117 |
return content
|
118 |
|
|
|
119 |
def pack_user_ass_to_openai_messages(*args: str):
|
120 |
roles = ["user", "assistant"]
|
121 |
return [
|
122 |
{"role": roles[i % 2], "content": content} for i, content in enumerate(args)
|
123 |
]
|
124 |
|
|
|
125 |
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
|
126 |
"""Split a string by multiple markers"""
|
127 |
if not markers:
|
@@ -129,6 +144,7 @@ def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]
|
|
129 |
results = re.split("|".join(re.escape(marker) for marker in markers), content)
|
130 |
return [r.strip() for r in results if r.strip()]
|
131 |
|
|
|
132 |
# Refer the utils functions of the official GraphRAG implementation:
|
133 |
# https://github.com/microsoft/graphrag
|
134 |
def clean_str(input: Any) -> str:
|
@@ -141,9 +157,11 @@ def clean_str(input: Any) -> str:
|
|
141 |
# https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
|
142 |
return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
|
143 |
|
|
|
144 |
def is_float_regex(value):
|
145 |
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
|
146 |
|
|
|
147 |
def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
|
148 |
"""Truncate a list of data by token size"""
|
149 |
if max_token_size <= 0:
|
@@ -155,11 +173,13 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
|
|
155 |
return list_data[:i]
|
156 |
return list_data
|
157 |
|
|
|
158 |
def list_of_list_to_csv(data: list[list]):
|
159 |
return "\n".join(
|
160 |
[",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data]
|
161 |
)
|
162 |
|
|
|
163 |
def save_data_to_file(data, file_name):
|
164 |
-
with open(file_name,
|
165 |
-
json.dump(data, f, ensure_ascii=False, indent=4)
|
|
|
16 |
|
17 |
logger = logging.getLogger("lightrag")
|
18 |
|
19 |
+
|
20 |
def set_logger(log_file: str):
|
21 |
logger.setLevel(logging.DEBUG)
|
22 |
|
23 |
file_handler = logging.FileHandler(log_file)
|
24 |
file_handler.setLevel(logging.DEBUG)
|
25 |
|
26 |
+
formatter = logging.Formatter(
|
27 |
+
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
28 |
+
)
|
29 |
file_handler.setFormatter(formatter)
|
30 |
|
31 |
if not logger.handlers:
|
32 |
logger.addHandler(file_handler)
|
33 |
|
34 |
+
|
35 |
@dataclass
|
36 |
class EmbeddingFunc:
|
37 |
embedding_dim: int
|
|
|
40 |
|
41 |
async def __call__(self, *args, **kwargs) -> np.ndarray:
|
42 |
return await self.func(*args, **kwargs)
|
43 |
+
|
44 |
+
|
45 |
def locate_json_string_body_from_string(content: str) -> Union[str, None]:
|
46 |
"""Locate the JSON string body from a string"""
|
47 |
maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
|
|
|
50 |
else:
|
51 |
return None
|
52 |
|
53 |
+
|
54 |
def convert_response_to_json(response: str) -> dict:
|
55 |
json_str = locate_json_string_body_from_string(response)
|
56 |
assert json_str is not None, f"Unable to parse JSON from response: {response}"
|
|
|
61 |
logger.error(f"Failed to parse JSON: {json_str}")
|
62 |
raise e from None
|
63 |
|
64 |
+
|
65 |
def compute_args_hash(*args):
|
66 |
return md5(str(args).encode()).hexdigest()
|
67 |
|
68 |
+
|
69 |
def compute_mdhash_id(content, prefix: str = ""):
|
70 |
return prefix + md5(content.encode()).hexdigest()
|
71 |
|
72 |
+
|
73 |
def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
|
74 |
"""Add restriction of maximum async calling times for a async func"""
|
75 |
|
|
|
91 |
|
92 |
return final_decro
|
93 |
|
94 |
+
|
95 |
def wrap_embedding_func_with_attrs(**kwargs):
|
96 |
"""Wrap a function with attributes"""
|
97 |
|
|
|
101 |
|
102 |
return final_decro
|
103 |
|
104 |
+
|
105 |
def load_json(file_name):
|
106 |
if not os.path.exists(file_name):
|
107 |
return None
|
108 |
with open(file_name, encoding="utf-8") as f:
|
109 |
return json.load(f)
|
110 |
|
111 |
+
|
112 |
def write_json(json_obj, file_name):
|
113 |
with open(file_name, "w", encoding="utf-8") as f:
|
114 |
json.dump(json_obj, f, indent=2, ensure_ascii=False)
|
115 |
|
116 |
+
|
117 |
def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
|
118 |
global ENCODER
|
119 |
if ENCODER is None:
|
|
|
129 |
content = ENCODER.decode(tokens)
|
130 |
return content
|
131 |
|
132 |
+
|
133 |
def pack_user_ass_to_openai_messages(*args: str):
|
134 |
roles = ["user", "assistant"]
|
135 |
return [
|
136 |
{"role": roles[i % 2], "content": content} for i, content in enumerate(args)
|
137 |
]
|
138 |
|
139 |
+
|
140 |
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
|
141 |
"""Split a string by multiple markers"""
|
142 |
if not markers:
|
|
|
144 |
results = re.split("|".join(re.escape(marker) for marker in markers), content)
|
145 |
return [r.strip() for r in results if r.strip()]
|
146 |
|
147 |
+
|
148 |
# Refer the utils functions of the official GraphRAG implementation:
|
149 |
# https://github.com/microsoft/graphrag
|
150 |
def clean_str(input: Any) -> str:
|
|
|
157 |
# https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
|
158 |
return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
|
159 |
|
160 |
+
|
161 |
def is_float_regex(value):
|
162 |
return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
|
163 |
|
164 |
+
|
165 |
def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
|
166 |
"""Truncate a list of data by token size"""
|
167 |
if max_token_size <= 0:
|
|
|
173 |
return list_data[:i]
|
174 |
return list_data
|
175 |
|
176 |
+
|
177 |
def list_of_list_to_csv(data: list[list]):
|
178 |
return "\n".join(
|
179 |
[",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data]
|
180 |
)
|
181 |
|
182 |
+
|
183 |
def save_data_to_file(data, file_name):
|
184 |
+
with open(file_name, "w", encoding="utf-8") as f:
|
185 |
+
json.dump(data, f, ensure_ascii=False, indent=4)
|
reproduce/Step_0.py
CHANGED
@@ -3,11 +3,11 @@ import json
|
|
3 |
import glob
|
4 |
import argparse
|
5 |
|
6 |
-
def extract_unique_contexts(input_directory, output_directory):
|
7 |
|
|
|
8 |
os.makedirs(output_directory, exist_ok=True)
|
9 |
|
10 |
-
jsonl_files = glob.glob(os.path.join(input_directory,
|
11 |
print(f"Found {len(jsonl_files)} JSONL files.")
|
12 |
|
13 |
for file_path in jsonl_files:
|
@@ -21,18 +21,20 @@ def extract_unique_contexts(input_directory, output_directory):
|
|
21 |
print(f"Processing file: {filename}")
|
22 |
|
23 |
try:
|
24 |
-
with open(file_path,
|
25 |
for line_number, line in enumerate(infile, start=1):
|
26 |
line = line.strip()
|
27 |
if not line:
|
28 |
continue
|
29 |
try:
|
30 |
json_obj = json.loads(line)
|
31 |
-
context = json_obj.get(
|
32 |
if context and context not in unique_contexts_dict:
|
33 |
unique_contexts_dict[context] = None
|
34 |
except json.JSONDecodeError as e:
|
35 |
-
print(
|
|
|
|
|
36 |
except FileNotFoundError:
|
37 |
print(f"File not found: {filename}")
|
38 |
continue
|
@@ -41,10 +43,12 @@ def extract_unique_contexts(input_directory, output_directory):
|
|
41 |
continue
|
42 |
|
43 |
unique_contexts_list = list(unique_contexts_dict.keys())
|
44 |
-
print(
|
|
|
|
|
45 |
|
46 |
try:
|
47 |
-
with open(output_path,
|
48 |
json.dump(unique_contexts_list, outfile, ensure_ascii=False, indent=4)
|
49 |
print(f"Unique `context` entries have been saved to: {output_filename}")
|
50 |
except Exception as e:
|
@@ -55,8 +59,10 @@ def extract_unique_contexts(input_directory, output_directory):
|
|
55 |
|
56 |
if __name__ == "__main__":
|
57 |
parser = argparse.ArgumentParser()
|
58 |
-
parser.add_argument(
|
59 |
-
parser.add_argument(
|
|
|
|
|
60 |
|
61 |
args = parser.parse_args()
|
62 |
|
|
|
3 |
import glob
|
4 |
import argparse
|
5 |
|
|
|
6 |
|
7 |
+
def extract_unique_contexts(input_directory, output_directory):
|
8 |
os.makedirs(output_directory, exist_ok=True)
|
9 |
|
10 |
+
jsonl_files = glob.glob(os.path.join(input_directory, "*.jsonl"))
|
11 |
print(f"Found {len(jsonl_files)} JSONL files.")
|
12 |
|
13 |
for file_path in jsonl_files:
|
|
|
21 |
print(f"Processing file: {filename}")
|
22 |
|
23 |
try:
|
24 |
+
with open(file_path, "r", encoding="utf-8") as infile:
|
25 |
for line_number, line in enumerate(infile, start=1):
|
26 |
line = line.strip()
|
27 |
if not line:
|
28 |
continue
|
29 |
try:
|
30 |
json_obj = json.loads(line)
|
31 |
+
context = json_obj.get("context")
|
32 |
if context and context not in unique_contexts_dict:
|
33 |
unique_contexts_dict[context] = None
|
34 |
except json.JSONDecodeError as e:
|
35 |
+
print(
|
36 |
+
f"JSON decoding error in file {filename} at line {line_number}: {e}"
|
37 |
+
)
|
38 |
except FileNotFoundError:
|
39 |
print(f"File not found: {filename}")
|
40 |
continue
|
|
|
43 |
continue
|
44 |
|
45 |
unique_contexts_list = list(unique_contexts_dict.keys())
|
46 |
+
print(
|
47 |
+
f"There are {len(unique_contexts_list)} unique `context` entries in the file {filename}."
|
48 |
+
)
|
49 |
|
50 |
try:
|
51 |
+
with open(output_path, "w", encoding="utf-8") as outfile:
|
52 |
json.dump(unique_contexts_list, outfile, ensure_ascii=False, indent=4)
|
53 |
print(f"Unique `context` entries have been saved to: {output_filename}")
|
54 |
except Exception as e:
|
|
|
59 |
|
60 |
if __name__ == "__main__":
|
61 |
parser = argparse.ArgumentParser()
|
62 |
+
parser.add_argument("-i", "--input_dir", type=str, default="../datasets")
|
63 |
+
parser.add_argument(
|
64 |
+
"-o", "--output_dir", type=str, default="../datasets/unique_contexts"
|
65 |
+
)
|
66 |
|
67 |
args = parser.parse_args()
|
68 |
|
reproduce/Step_1.py
CHANGED
@@ -4,10 +4,11 @@ import time
|
|
4 |
|
5 |
from lightrag import LightRAG
|
6 |
|
|
|
7 |
def insert_text(rag, file_path):
|
8 |
-
with open(file_path, mode=
|
9 |
unique_contexts = json.load(f)
|
10 |
-
|
11 |
retries = 0
|
12 |
max_retries = 3
|
13 |
while retries < max_retries:
|
@@ -21,6 +22,7 @@ def insert_text(rag, file_path):
|
|
21 |
if retries == max_retries:
|
22 |
print("Insertion failed after exceeding the maximum number of retries")
|
23 |
|
|
|
24 |
cls = "agriculture"
|
25 |
WORKING_DIR = "../{cls}"
|
26 |
|
@@ -29,4 +31,4 @@ if not os.path.exists(WORKING_DIR):
|
|
29 |
|
30 |
rag = LightRAG(working_dir=WORKING_DIR)
|
31 |
|
32 |
-
insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")
|
|
|
4 |
|
5 |
from lightrag import LightRAG
|
6 |
|
7 |
+
|
8 |
def insert_text(rag, file_path):
|
9 |
+
with open(file_path, mode="r") as f:
|
10 |
unique_contexts = json.load(f)
|
11 |
+
|
12 |
retries = 0
|
13 |
max_retries = 3
|
14 |
while retries < max_retries:
|
|
|
22 |
if retries == max_retries:
|
23 |
print("Insertion failed after exceeding the maximum number of retries")
|
24 |
|
25 |
+
|
26 |
cls = "agriculture"
|
27 |
WORKING_DIR = "../{cls}"
|
28 |
|
|
|
31 |
|
32 |
rag = LightRAG(working_dir=WORKING_DIR)
|
33 |
|
34 |
+
insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")
|
reproduce/Step_1_openai_compatible.py
CHANGED
@@ -7,6 +7,7 @@ from lightrag import LightRAG
|
|
7 |
from lightrag.utils import EmbeddingFunc
|
8 |
from lightrag.llm import openai_complete_if_cache, openai_embedding
|
9 |
|
|
|
10 |
## For Upstage API
|
11 |
# please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag direcotry
|
12 |
async def llm_model_func(
|
@@ -19,22 +20,26 @@ async def llm_model_func(
|
|
19 |
history_messages=history_messages,
|
20 |
api_key=os.getenv("UPSTAGE_API_KEY"),
|
21 |
base_url="https://api.upstage.ai/v1/solar",
|
22 |
-
**kwargs
|
23 |
)
|
24 |
|
|
|
25 |
async def embedding_func(texts: list[str]) -> np.ndarray:
|
26 |
return await openai_embedding(
|
27 |
texts,
|
28 |
model="solar-embedding-1-large-query",
|
29 |
api_key=os.getenv("UPSTAGE_API_KEY"),
|
30 |
-
base_url="https://api.upstage.ai/v1/solar"
|
31 |
)
|
|
|
|
|
32 |
## /For Upstage API
|
33 |
|
|
|
34 |
def insert_text(rag, file_path):
|
35 |
-
with open(file_path, mode=
|
36 |
unique_contexts = json.load(f)
|
37 |
-
|
38 |
retries = 0
|
39 |
max_retries = 3
|
40 |
while retries < max_retries:
|
@@ -48,19 +53,19 @@ def insert_text(rag, file_path):
|
|
48 |
if retries == max_retries:
|
49 |
print("Insertion failed after exceeding the maximum number of retries")
|
50 |
|
|
|
51 |
cls = "mix"
|
52 |
WORKING_DIR = f"../{cls}"
|
53 |
|
54 |
if not os.path.exists(WORKING_DIR):
|
55 |
os.mkdir(WORKING_DIR)
|
56 |
|
57 |
-
rag = LightRAG(
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
)
|
65 |
|
66 |
insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")
|
|
|
7 |
from lightrag.utils import EmbeddingFunc
|
8 |
from lightrag.llm import openai_complete_if_cache, openai_embedding
|
9 |
|
10 |
+
|
11 |
## For Upstage API
|
12 |
# please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag direcotry
|
13 |
async def llm_model_func(
|
|
|
20 |
history_messages=history_messages,
|
21 |
api_key=os.getenv("UPSTAGE_API_KEY"),
|
22 |
base_url="https://api.upstage.ai/v1/solar",
|
23 |
+
**kwargs,
|
24 |
)
|
25 |
|
26 |
+
|
27 |
async def embedding_func(texts: list[str]) -> np.ndarray:
|
28 |
return await openai_embedding(
|
29 |
texts,
|
30 |
model="solar-embedding-1-large-query",
|
31 |
api_key=os.getenv("UPSTAGE_API_KEY"),
|
32 |
+
base_url="https://api.upstage.ai/v1/solar",
|
33 |
)
|
34 |
+
|
35 |
+
|
36 |
## /For Upstage API
|
37 |
|
38 |
+
|
39 |
def insert_text(rag, file_path):
|
40 |
+
with open(file_path, mode="r") as f:
|
41 |
unique_contexts = json.load(f)
|
42 |
+
|
43 |
retries = 0
|
44 |
max_retries = 3
|
45 |
while retries < max_retries:
|
|
|
53 |
if retries == max_retries:
|
54 |
print("Insertion failed after exceeding the maximum number of retries")
|
55 |
|
56 |
+
|
57 |
cls = "mix"
|
58 |
WORKING_DIR = f"../{cls}"
|
59 |
|
60 |
if not os.path.exists(WORKING_DIR):
|
61 |
os.mkdir(WORKING_DIR)
|
62 |
|
63 |
+
rag = LightRAG(
|
64 |
+
working_dir=WORKING_DIR,
|
65 |
+
llm_model_func=llm_model_func,
|
66 |
+
embedding_func=EmbeddingFunc(
|
67 |
+
embedding_dim=4096, max_token_size=8192, func=embedding_func
|
68 |
+
),
|
69 |
+
)
|
|
|
70 |
|
71 |
insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")
|
reproduce/Step_2.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
-
import os
|
2 |
import json
|
3 |
from openai import OpenAI
|
4 |
from transformers import GPT2Tokenizer
|
5 |
|
|
|
6 |
def openai_complete_if_cache(
|
7 |
model="gpt-4o", prompt=None, system_prompt=None, history_messages=[], **kwargs
|
8 |
) -> str:
|
@@ -19,24 +19,26 @@ def openai_complete_if_cache(
|
|
19 |
)
|
20 |
return response.choices[0].message.content
|
21 |
|
22 |
-
|
|
|
|
|
23 |
|
24 |
def get_summary(context, tot_tokens=2000):
|
25 |
tokens = tokenizer.tokenize(context)
|
26 |
half_tokens = tot_tokens // 2
|
27 |
|
28 |
-
start_tokens = tokens[1000:1000 + half_tokens]
|
29 |
-
end_tokens = tokens[-(1000 + half_tokens):1000]
|
30 |
|
31 |
summary_tokens = start_tokens + end_tokens
|
32 |
summary = tokenizer.convert_tokens_to_string(summary_tokens)
|
33 |
-
|
34 |
return summary
|
35 |
|
36 |
|
37 |
-
clses = [
|
38 |
for cls in clses:
|
39 |
-
with open(f
|
40 |
unique_contexts = json.load(f)
|
41 |
|
42 |
summaries = [get_summary(context) for context in unique_contexts]
|
@@ -67,10 +69,10 @@ for cls in clses:
|
|
67 |
...
|
68 |
"""
|
69 |
|
70 |
-
result = openai_complete_if_cache(model=
|
71 |
|
72 |
file_path = f"../datasets/questions/{cls}_questions.txt"
|
73 |
with open(file_path, "w") as file:
|
74 |
file.write(result)
|
75 |
|
76 |
-
print(f"{cls}_questions written to {file_path}")
|
|
|
|
|
1 |
import json
|
2 |
from openai import OpenAI
|
3 |
from transformers import GPT2Tokenizer
|
4 |
|
5 |
+
|
6 |
def openai_complete_if_cache(
|
7 |
model="gpt-4o", prompt=None, system_prompt=None, history_messages=[], **kwargs
|
8 |
) -> str:
|
|
|
19 |
)
|
20 |
return response.choices[0].message.content
|
21 |
|
22 |
+
|
23 |
+
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
24 |
+
|
25 |
|
26 |
def get_summary(context, tot_tokens=2000):
|
27 |
tokens = tokenizer.tokenize(context)
|
28 |
half_tokens = tot_tokens // 2
|
29 |
|
30 |
+
start_tokens = tokens[1000 : 1000 + half_tokens]
|
31 |
+
end_tokens = tokens[-(1000 + half_tokens) : 1000]
|
32 |
|
33 |
summary_tokens = start_tokens + end_tokens
|
34 |
summary = tokenizer.convert_tokens_to_string(summary_tokens)
|
35 |
+
|
36 |
return summary
|
37 |
|
38 |
|
39 |
+
clses = ["agriculture"]
|
40 |
for cls in clses:
|
41 |
+
with open(f"../datasets/unique_contexts/{cls}_unique_contexts.json", mode="r") as f:
|
42 |
unique_contexts = json.load(f)
|
43 |
|
44 |
summaries = [get_summary(context) for context in unique_contexts]
|
|
|
69 |
...
|
70 |
"""
|
71 |
|
72 |
+
result = openai_complete_if_cache(model="gpt-4o", prompt=prompt)
|
73 |
|
74 |
file_path = f"../datasets/questions/{cls}_questions.txt"
|
75 |
with open(file_path, "w") as file:
|
76 |
file.write(result)
|
77 |
|
78 |
+
print(f"{cls}_questions written to {file_path}")
|
reproduce/Step_3.py
CHANGED
@@ -4,16 +4,18 @@ import asyncio
|
|
4 |
from lightrag import LightRAG, QueryParam
|
5 |
from tqdm import tqdm
|
6 |
|
|
|
7 |
def extract_queries(file_path):
|
8 |
-
with open(file_path,
|
9 |
data = f.read()
|
10 |
-
|
11 |
-
data = data.replace('**', '')
|
12 |
|
13 |
-
|
|
|
|
|
14 |
|
15 |
return queries
|
16 |
|
|
|
17 |
async def process_query(query_text, rag_instance, query_param):
|
18 |
try:
|
19 |
result, context = await rag_instance.aquery(query_text, param=query_param)
|
@@ -21,6 +23,7 @@ async def process_query(query_text, rag_instance, query_param):
|
|
21 |
except Exception as e:
|
22 |
return None, {"query": query_text, "error": str(e)}
|
23 |
|
|
|
24 |
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
25 |
try:
|
26 |
loop = asyncio.get_event_loop()
|
@@ -29,15 +32,22 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
|
29 |
asyncio.set_event_loop(loop)
|
30 |
return loop
|
31 |
|
32 |
-
|
|
|
|
|
|
|
33 |
loop = always_get_an_event_loop()
|
34 |
|
35 |
-
with open(output_file,
|
|
|
|
|
36 |
result_file.write("[\n")
|
37 |
first_entry = True
|
38 |
|
39 |
for query_text in tqdm(queries, desc="Processing queries", unit="query"):
|
40 |
-
result, error = loop.run_until_complete(
|
|
|
|
|
41 |
|
42 |
if result:
|
43 |
if not first_entry:
|
@@ -50,6 +60,7 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file
|
|
50 |
|
51 |
result_file.write("\n]")
|
52 |
|
|
|
53 |
if __name__ == "__main__":
|
54 |
cls = "agriculture"
|
55 |
mode = "hybrid"
|
@@ -59,4 +70,6 @@ if __name__ == "__main__":
|
|
59 |
query_param = QueryParam(mode=mode)
|
60 |
|
61 |
queries = extract_queries(f"../datasets/questions/{cls}_questions.txt")
|
62 |
-
run_queries_and_save_to_json(
|
|
|
|
|
|
4 |
from lightrag import LightRAG, QueryParam
|
5 |
from tqdm import tqdm
|
6 |
|
7 |
+
|
8 |
def extract_queries(file_path):
|
9 |
+
with open(file_path, "r") as f:
|
10 |
data = f.read()
|
|
|
|
|
11 |
|
12 |
+
data = data.replace("**", "")
|
13 |
+
|
14 |
+
queries = re.findall(r"- Question \d+: (.+)", data)
|
15 |
|
16 |
return queries
|
17 |
|
18 |
+
|
19 |
async def process_query(query_text, rag_instance, query_param):
|
20 |
try:
|
21 |
result, context = await rag_instance.aquery(query_text, param=query_param)
|
|
|
23 |
except Exception as e:
|
24 |
return None, {"query": query_text, "error": str(e)}
|
25 |
|
26 |
+
|
27 |
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
28 |
try:
|
29 |
loop = asyncio.get_event_loop()
|
|
|
32 |
asyncio.set_event_loop(loop)
|
33 |
return loop
|
34 |
|
35 |
+
|
36 |
+
def run_queries_and_save_to_json(
|
37 |
+
queries, rag_instance, query_param, output_file, error_file
|
38 |
+
):
|
39 |
loop = always_get_an_event_loop()
|
40 |
|
41 |
+
with open(output_file, "a", encoding="utf-8") as result_file, open(
|
42 |
+
error_file, "a", encoding="utf-8"
|
43 |
+
) as err_file:
|
44 |
result_file.write("[\n")
|
45 |
first_entry = True
|
46 |
|
47 |
for query_text in tqdm(queries, desc="Processing queries", unit="query"):
|
48 |
+
result, error = loop.run_until_complete(
|
49 |
+
process_query(query_text, rag_instance, query_param)
|
50 |
+
)
|
51 |
|
52 |
if result:
|
53 |
if not first_entry:
|
|
|
60 |
|
61 |
result_file.write("\n]")
|
62 |
|
63 |
+
|
64 |
if __name__ == "__main__":
|
65 |
cls = "agriculture"
|
66 |
mode = "hybrid"
|
|
|
70 |
query_param = QueryParam(mode=mode)
|
71 |
|
72 |
queries = extract_queries(f"../datasets/questions/{cls}_questions.txt")
|
73 |
+
run_queries_and_save_to_json(
|
74 |
+
queries, rag, query_param, f"{cls}_result.json", f"{cls}_errors.json"
|
75 |
+
)
|
reproduce/Step_3_openai_compatible.py
CHANGED
@@ -8,6 +8,7 @@ from lightrag.llm import openai_complete_if_cache, openai_embedding
|
|
8 |
from lightrag.utils import EmbeddingFunc
|
9 |
import numpy as np
|
10 |
|
|
|
11 |
## For Upstage API
|
12 |
# please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag direcotry
|
13 |
async def llm_model_func(
|
@@ -20,28 +21,33 @@ async def llm_model_func(
|
|
20 |
history_messages=history_messages,
|
21 |
api_key=os.getenv("UPSTAGE_API_KEY"),
|
22 |
base_url="https://api.upstage.ai/v1/solar",
|
23 |
-
**kwargs
|
24 |
)
|
25 |
|
|
|
26 |
async def embedding_func(texts: list[str]) -> np.ndarray:
|
27 |
return await openai_embedding(
|
28 |
texts,
|
29 |
model="solar-embedding-1-large-query",
|
30 |
api_key=os.getenv("UPSTAGE_API_KEY"),
|
31 |
-
base_url="https://api.upstage.ai/v1/solar"
|
32 |
)
|
|
|
|
|
33 |
## /For Upstage API
|
34 |
|
|
|
35 |
def extract_queries(file_path):
|
36 |
-
with open(file_path,
|
37 |
data = f.read()
|
38 |
-
|
39 |
-
data = data.replace('**', '')
|
40 |
|
41 |
-
|
|
|
|
|
42 |
|
43 |
return queries
|
44 |
|
|
|
45 |
async def process_query(query_text, rag_instance, query_param):
|
46 |
try:
|
47 |
result, context = await rag_instance.aquery(query_text, param=query_param)
|
@@ -49,6 +55,7 @@ async def process_query(query_text, rag_instance, query_param):
|
|
49 |
except Exception as e:
|
50 |
return None, {"query": query_text, "error": str(e)}
|
51 |
|
|
|
52 |
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
53 |
try:
|
54 |
loop = asyncio.get_event_loop()
|
@@ -57,15 +64,22 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
|
57 |
asyncio.set_event_loop(loop)
|
58 |
return loop
|
59 |
|
60 |
-
|
|
|
|
|
|
|
61 |
loop = always_get_an_event_loop()
|
62 |
|
63 |
-
with open(output_file,
|
|
|
|
|
64 |
result_file.write("[\n")
|
65 |
first_entry = True
|
66 |
|
67 |
for query_text in tqdm(queries, desc="Processing queries", unit="query"):
|
68 |
-
result, error = loop.run_until_complete(
|
|
|
|
|
69 |
|
70 |
if result:
|
71 |
if not first_entry:
|
@@ -78,22 +92,24 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file
|
|
78 |
|
79 |
result_file.write("\n]")
|
80 |
|
|
|
81 |
if __name__ == "__main__":
|
82 |
cls = "mix"
|
83 |
mode = "hybrid"
|
84 |
WORKING_DIR = f"../{cls}"
|
85 |
|
86 |
rag = LightRAG(working_dir=WORKING_DIR)
|
87 |
-
rag = LightRAG(
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
)
|
95 |
query_param = QueryParam(mode=mode)
|
96 |
|
97 |
-
base_dir=
|
98 |
queries = extract_queries(f"{base_dir}/{cls}_questions.txt")
|
99 |
-
run_queries_and_save_to_json(
|
|
|
|
|
|
8 |
from lightrag.utils import EmbeddingFunc
|
9 |
import numpy as np
|
10 |
|
11 |
+
|
12 |
## For Upstage API
|
13 |
# please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag direcotry
|
14 |
async def llm_model_func(
|
|
|
21 |
history_messages=history_messages,
|
22 |
api_key=os.getenv("UPSTAGE_API_KEY"),
|
23 |
base_url="https://api.upstage.ai/v1/solar",
|
24 |
+
**kwargs,
|
25 |
)
|
26 |
|
27 |
+
|
28 |
async def embedding_func(texts: list[str]) -> np.ndarray:
|
29 |
return await openai_embedding(
|
30 |
texts,
|
31 |
model="solar-embedding-1-large-query",
|
32 |
api_key=os.getenv("UPSTAGE_API_KEY"),
|
33 |
+
base_url="https://api.upstage.ai/v1/solar",
|
34 |
)
|
35 |
+
|
36 |
+
|
37 |
## /For Upstage API
|
38 |
|
39 |
+
|
40 |
def extract_queries(file_path):
|
41 |
+
with open(file_path, "r") as f:
|
42 |
data = f.read()
|
|
|
|
|
43 |
|
44 |
+
data = data.replace("**", "")
|
45 |
+
|
46 |
+
queries = re.findall(r"- Question \d+: (.+)", data)
|
47 |
|
48 |
return queries
|
49 |
|
50 |
+
|
51 |
async def process_query(query_text, rag_instance, query_param):
|
52 |
try:
|
53 |
result, context = await rag_instance.aquery(query_text, param=query_param)
|
|
|
55 |
except Exception as e:
|
56 |
return None, {"query": query_text, "error": str(e)}
|
57 |
|
58 |
+
|
59 |
def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
|
60 |
try:
|
61 |
loop = asyncio.get_event_loop()
|
|
|
64 |
asyncio.set_event_loop(loop)
|
65 |
return loop
|
66 |
|
67 |
+
|
68 |
+
def run_queries_and_save_to_json(
|
69 |
+
queries, rag_instance, query_param, output_file, error_file
|
70 |
+
):
|
71 |
loop = always_get_an_event_loop()
|
72 |
|
73 |
+
with open(output_file, "a", encoding="utf-8") as result_file, open(
|
74 |
+
error_file, "a", encoding="utf-8"
|
75 |
+
) as err_file:
|
76 |
result_file.write("[\n")
|
77 |
first_entry = True
|
78 |
|
79 |
for query_text in tqdm(queries, desc="Processing queries", unit="query"):
|
80 |
+
result, error = loop.run_until_complete(
|
81 |
+
process_query(query_text, rag_instance, query_param)
|
82 |
+
)
|
83 |
|
84 |
if result:
|
85 |
if not first_entry:
|
|
|
92 |
|
93 |
result_file.write("\n]")
|
94 |
|
95 |
+
|
96 |
if __name__ == "__main__":
|
97 |
cls = "mix"
|
98 |
mode = "hybrid"
|
99 |
WORKING_DIR = f"../{cls}"
|
100 |
|
101 |
rag = LightRAG(working_dir=WORKING_DIR)
|
102 |
+
rag = LightRAG(
|
103 |
+
working_dir=WORKING_DIR,
|
104 |
+
llm_model_func=llm_model_func,
|
105 |
+
embedding_func=EmbeddingFunc(
|
106 |
+
embedding_dim=4096, max_token_size=8192, func=embedding_func
|
107 |
+
),
|
108 |
+
)
|
|
|
109 |
query_param = QueryParam(mode=mode)
|
110 |
|
111 |
+
base_dir = "../datasets/questions"
|
112 |
queries = extract_queries(f"{base_dir}/{cls}_questions.txt")
|
113 |
+
run_queries_and_save_to_json(
|
114 |
+
queries, rag, query_param, f"{base_dir}/result.json", f"{base_dir}/errors.json"
|
115 |
+
)
|
requirements.txt
CHANGED
@@ -1,13 +1,13 @@
|
|
|
|
1 |
aioboto3
|
2 |
-
openai
|
3 |
-
tiktoken
|
4 |
-
networkx
|
5 |
graspologic
|
6 |
-
nano-vectordb
|
7 |
hnswlib
|
8 |
-
|
|
|
|
|
|
|
9 |
tenacity
|
10 |
-
|
11 |
torch
|
12 |
-
|
13 |
-
|
|
|
1 |
+
accelerate
|
2 |
aioboto3
|
|
|
|
|
|
|
3 |
graspologic
|
|
|
4 |
hnswlib
|
5 |
+
nano-vectordb
|
6 |
+
networkx
|
7 |
+
ollama
|
8 |
+
openai
|
9 |
tenacity
|
10 |
+
tiktoken
|
11 |
torch
|
12 |
+
transformers
|
13 |
+
xxhash
|