zrguo committed on
Commit 5482815 · 2 Parent(s): 8de9098 df22b26

Merge pull request #56 from sank8-2/dev


chore: added pre-commit-hooks and ruff formatting for commit-hooks

.gitignore CHANGED
@@ -1,4 +1,5 @@
1
  __pycache__
2
  *.egg-info
3
  dickens/
4
- book.txt
 
 
1
  __pycache__
2
  *.egg-info
3
  dickens/
4
+ book.txt
5
+ lightrag-dev/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,22 @@
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: v5.0.0
4
+ hooks:
5
+ - id: trailing-whitespace
6
+ - id: end-of-file-fixer
7
+ - id: requirements-txt-fixer
8
+
9
+
10
+ - repo: https://github.com/astral-sh/ruff-pre-commit
11
+ rev: v0.6.4
12
+ hooks:
13
+ - id: ruff-format
14
+ - id: ruff
15
+ args: [--fix]
16
+
17
+
18
+ - repo: https://github.com/mgedmin/check-manifest
19
+ rev: "0.49"
20
+ hooks:
21
+ - id: check-manifest
22
+ stages: [manual]
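The new configuration wires three hook sets into git commits: the standard pre-commit fixers (trailing whitespace, end-of-file newlines, requirements.txt ordering), Ruff formatting plus linting with `--fix`, and a `check-manifest` hook restricted to the manual stage. Contributors typically enable the hooks once with `pre-commit install` and can apply them to the whole tree with `pre-commit run --all-files`; those are the standard pre-commit commands, though this repository's exact contributor workflow is not spelled out in the commit.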
README.md CHANGED
@@ -16,16 +16,16 @@
16
  <a href="https://pypi.org/project/lightrag-hku/"><img src="https://img.shields.io/pypi/v/lightrag-hku.svg"></a>
17
  <a href="https://pepy.tech/project/lightrag-hku"><img src="https://static.pepy.tech/badge/lightrag-hku/month"></a>
18
  </p>
19
-
20
  This repository hosts the code of LightRAG. The structure of this code is based on [nano-graphrag](https://github.com/gusye1234/nano-graphrag).
21
  ![请添加图片描述](https://i-blog.csdnimg.cn/direct/b2aaf634151b4706892693ffb43d9093.png)
22
  </div>
23
 
24
- ## 🎉 News
25
  - [x] [2024.10.18]🎯🎯📢📢We’ve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
26
  - [x] [2024.10.17]🎯🎯📢📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉
27
- - [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
28
- - [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
29
 
30
  ## Install
31
 
@@ -92,7 +92,7 @@ print(rag.query("What are the top themes in this story?", param=QueryParam(mode=
92
  <details>
93
  <summary> Using Open AI-like APIs </summary>
94
 
95
- LightRAG also support Open AI-like chat/embeddings APIs:
96
  ```python
97
  async def llm_model_func(
98
  prompt, system_prompt=None, history_messages=[], **kwargs
@@ -129,7 +129,7 @@ rag = LightRAG(
129
 
130
  <details>
131
  <summary> Using Hugging Face Models </summary>
132
-
133
  If you want to use Hugging Face models, you only need to set LightRAG as follows:
134
  ```python
135
  from lightrag.llm import hf_model_complete, hf_embedding
@@ -145,7 +145,7 @@ rag = LightRAG(
145
  embedding_dim=384,
146
  max_token_size=5000,
147
  func=lambda texts: hf_embedding(
148
- texts,
149
  tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
150
  embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
151
  )
@@ -157,7 +157,7 @@ rag = LightRAG(
157
  <details>
158
  <summary> Using Ollama Models </summary>
159
  If you want to use Ollama models, you only need to set LightRAG as follows:
160
-
161
  ```python
162
  from lightrag.llm import ollama_model_complete, ollama_embedding
163
 
@@ -171,7 +171,7 @@ rag = LightRAG(
171
  embedding_dim=768,
172
  max_token_size=8192,
173
  func=lambda texts: ollama_embedding(
174
- texts,
175
  embed_model="nomic-embed-text"
176
  )
177
  ),
@@ -196,14 +196,14 @@ with open("./newText.txt") as f:
196
  ```
197
  ## Evaluation
198
  ### Dataset
199
- The dataset used in LightRAG can be download from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
200
 
201
  ### Generate Query
202
- LightRAG uses the following prompt to generate high-level queries, with the corresponding code located in `example/generate_query.py`.
203
 
204
  <details>
205
  <summary> Prompt </summary>
206
-
207
  ```python
208
  Given the following description of a dataset:
209
 
@@ -228,18 +228,18 @@ Output the results in the following structure:
228
  ...
229
  ```
230
  </details>
231
-
232
  ### Batch Eval
233
  To evaluate the performance of two RAG systems on high-level queries, LightRAG uses the following prompt, with the specific code available in `example/batch_eval.py`.
234
 
235
  <details>
236
  <summary> Prompt </summary>
237
-
238
  ```python
239
  ---Role---
240
  You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
241
  ---Goal---
242
- You will evaluate two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
243
 
244
  - **Comprehensiveness**: How much detail does the answer provide to cover all aspects and details of the question?
245
  - **Diversity**: How varied and rich is the answer in providing different perspectives and insights on the question?
@@ -303,7 +303,7 @@ Output your evaluation in the following JSON format:
303
  | **Empowerment** | 36.69% | **63.31%** | 45.09% | **54.91%** | 42.81% | **57.19%** | **52.94%** | 47.06% |
304
  | **Overall** | 43.62% | **56.38%** | 45.98% | **54.02%** | 45.70% | **54.30%** | **51.86%** | 48.14% |
305
 
306
- ## Reproduce
307
  All the code can be found in the `./reproduce` directory.
308
 
309
  ### Step-0 Extract Unique Contexts
@@ -311,7 +311,7 @@ First, we need to extract unique contexts in the datasets.
311
 
312
  <details>
313
  <summary> Code </summary>
314
-
315
  ```python
316
  def extract_unique_contexts(input_directory, output_directory):
317
 
@@ -370,12 +370,12 @@ For the extracted contexts, we insert them into the LightRAG system.
370
 
371
  <details>
372
  <summary> Code </summary>
373
-
374
  ```python
375
  def insert_text(rag, file_path):
376
  with open(file_path, mode='r') as f:
377
  unique_contexts = json.load(f)
378
-
379
  retries = 0
380
  max_retries = 3
381
  while retries < max_retries:
@@ -393,11 +393,11 @@ def insert_text(rag, file_path):
393
 
394
  ### Step-2 Generate Queries
395
 
396
- We extract tokens from both the first half and the second half of each context in the dataset, then combine them as the dataset description to generate queries.
397
 
398
  <details>
399
  <summary> Code </summary>
400
-
401
  ```python
402
  tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
403
 
@@ -410,7 +410,7 @@ def get_summary(context, tot_tokens=2000):
410
 
411
  summary_tokens = start_tokens + end_tokens
412
  summary = tokenizer.convert_tokens_to_string(summary_tokens)
413
-
414
  return summary
415
  ```
416
  </details>
@@ -420,12 +420,12 @@ For the queries generated in Step-2, we will extract them and query LightRAG.
420
 
421
  <details>
422
  <summary> Code </summary>
423
-
424
  ```python
425
  def extract_queries(file_path):
426
  with open(file_path, 'r') as f:
427
  data = f.read()
428
-
429
  data = data.replace('**', '')
430
 
431
  queries = re.findall(r'- Question \d+: (.+)', data)
@@ -479,7 +479,7 @@ def extract_queries(file_path):
479
 
480
  ```python
481
  @article{guo2024lightrag,
482
- title={LightRAG: Simple and Fast Retrieval-Augmented Generation},
483
  author={Zirui Guo and Lianghao Xia and Yanhua Yu and Tu Ao and Chao Huang},
484
  year={2024},
485
  eprint={2410.05779},
 
16
  <a href="https://pypi.org/project/lightrag-hku/"><img src="https://img.shields.io/pypi/v/lightrag-hku.svg"></a>
17
  <a href="https://pepy.tech/project/lightrag-hku"><img src="https://static.pepy.tech/badge/lightrag-hku/month"></a>
18
  </p>
19
+
20
  This repository hosts the code of LightRAG. The structure of this code is based on [nano-graphrag](https://github.com/gusye1234/nano-graphrag).
21
  ![请添加图片描述](https://i-blog.csdnimg.cn/direct/b2aaf634151b4706892693ffb43d9093.png)
22
  </div>
23
 
24
+ ## 🎉 News
25
  - [x] [2024.10.18]🎯🎯📢📢We’ve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
26
  - [x] [2024.10.17]🎯🎯📢📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉
27
+ - [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
28
+ - [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
29
 
30
  ## Install
31
 
 
92
  <details>
93
  <summary> Using Open AI-like APIs </summary>
94
 
95
+ LightRAG also supports Open AI-like chat/embeddings APIs:
96
  ```python
97
  async def llm_model_func(
98
  prompt, system_prompt=None, history_messages=[], **kwargs
 
129
 
130
  <details>
131
  <summary> Using Hugging Face Models </summary>
132
+
133
  If you want to use Hugging Face models, you only need to set LightRAG as follows:
134
  ```python
135
  from lightrag.llm import hf_model_complete, hf_embedding
 
145
  embedding_dim=384,
146
  max_token_size=5000,
147
  func=lambda texts: hf_embedding(
148
+ texts,
149
  tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
150
  embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
151
  )
 
157
  <details>
158
  <summary> Using Ollama Models </summary>
159
  If you want to use Ollama models, you only need to set LightRAG as follows:
160
+
161
  ```python
162
  from lightrag.llm import ollama_model_complete, ollama_embedding
163
 
 
171
  embedding_dim=768,
172
  max_token_size=8192,
173
  func=lambda texts: ollama_embedding(
174
+ texts,
175
  embed_model="nomic-embed-text"
176
  )
177
  ),
 
196
  ```
197
  ## Evaluation
198
  ### Dataset
199
+ The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
200
 
201
  ### Generate Query
202
+ LightRAG uses the following prompt to generate high-level queries, with the corresponding code in `example/generate_query.py`.
203
 
204
  <details>
205
  <summary> Prompt </summary>
206
+
207
  ```python
208
  Given the following description of a dataset:
209
 
 
228
  ...
229
  ```
230
  </details>
231
+
232
  ### Batch Eval
233
  To evaluate the performance of two RAG systems on high-level queries, LightRAG uses the following prompt, with the specific code available in `example/batch_eval.py`.
234
 
235
  <details>
236
  <summary> Prompt </summary>
237
+
238
  ```python
239
  ---Role---
240
  You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
241
  ---Goal---
242
+ You will evaluate two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
243
 
244
  - **Comprehensiveness**: How much detail does the answer provide to cover all aspects and details of the question?
245
  - **Diversity**: How varied and rich is the answer in providing different perspectives and insights on the question?
 
303
  | **Empowerment** | 36.69% | **63.31%** | 45.09% | **54.91%** | 42.81% | **57.19%** | **52.94%** | 47.06% |
304
  | **Overall** | 43.62% | **56.38%** | 45.98% | **54.02%** | 45.70% | **54.30%** | **51.86%** | 48.14% |
305
 
306
+ ## Reproduce
307
  All the code can be found in the `./reproduce` directory.
308
 
309
  ### Step-0 Extract Unique Contexts
 
311
 
312
  <details>
313
  <summary> Code </summary>
314
+
315
  ```python
316
  def extract_unique_contexts(input_directory, output_directory):
317
 
 
370
 
371
  <details>
372
  <summary> Code </summary>
373
+
374
  ```python
375
  def insert_text(rag, file_path):
376
  with open(file_path, mode='r') as f:
377
  unique_contexts = json.load(f)
378
+
379
  retries = 0
380
  max_retries = 3
381
  while retries < max_retries:
 
393
 
394
  ### Step-2 Generate Queries
395
 
396
+ We extract tokens from the first and the second half of each context in the dataset, then combine them as dataset descriptions to generate queries.
397
 
398
  <details>
399
  <summary> Code </summary>
400
+
401
  ```python
402
  tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
403
 
 
410
 
411
  summary_tokens = start_tokens + end_tokens
412
  summary = tokenizer.convert_tokens_to_string(summary_tokens)
413
+
414
  return summary
415
  ```
416
  </details>
 
420
 
421
  <details>
422
  <summary> Code </summary>
423
+
424
  ```python
425
  def extract_queries(file_path):
426
  with open(file_path, 'r') as f:
427
  data = f.read()
428
+
429
  data = data.replace('**', '')
430
 
431
  queries = re.findall(r'- Question \d+: (.+)', data)
 
479
 
480
  ```python
481
  @article{guo2024lightrag,
482
+ title={LightRAG: Simple and Fast Retrieval-Augmented Generation},
483
  author={Zirui Guo and Lianghao Xia and Yanhua Yu and Tu Ao and Chao Huang},
484
  year={2024},
485
  eprint={2410.05779},
examples/batch_eval.py CHANGED
@@ -1,4 +1,3 @@
1
- import os
2
  import re
3
  import json
4
  import jsonlines
@@ -9,28 +8,28 @@ from openai import OpenAI
9
  def batch_eval(query_file, result1_file, result2_file, output_file_path):
10
  client = OpenAI()
11
 
12
- with open(query_file, 'r') as f:
13
  data = f.read()
14
 
15
- queries = re.findall(r'- Question \d+: (.+)', data)
16
 
17
- with open(result1_file, 'r') as f:
18
  answers1 = json.load(f)
19
- answers1 = [i['result'] for i in answers1]
20
 
21
- with open(result2_file, 'r') as f:
22
  answers2 = json.load(f)
23
- answers2 = [i['result'] for i in answers2]
24
 
25
  requests = []
26
  for i, (query, answer1, answer2) in enumerate(zip(queries, answers1, answers2)):
27
- sys_prompt = f"""
28
  ---Role---
29
  You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
30
  """
31
 
32
  prompt = f"""
33
- You will evaluate two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
34
 
35
  - **Comprehensiveness**: How much detail does the answer provide to cover all aspects and details of the question?
36
  - **Diversity**: How varied and rich is the answer in providing different perspectives and insights on the question?
@@ -69,7 +68,6 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
69
  }}
70
  """
71
 
72
-
73
  request_data = {
74
  "custom_id": f"request-{i+1}",
75
  "method": "POST",
@@ -78,22 +76,21 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
78
  "model": "gpt-4o-mini",
79
  "messages": [
80
  {"role": "system", "content": sys_prompt},
81
- {"role": "user", "content": prompt}
82
  ],
83
- }
84
  }
85
-
86
  requests.append(request_data)
87
 
88
- with jsonlines.open(output_file_path, mode='w') as writer:
89
  for request in requests:
90
  writer.write(request)
91
 
92
  print(f"Batch API requests written to {output_file_path}")
93
 
94
  batch_input_file = client.files.create(
95
- file=open(output_file_path, "rb"),
96
- purpose="batch"
97
  )
98
  batch_input_file_id = batch_input_file.id
99
 
@@ -101,12 +98,11 @@ def batch_eval(query_file, result1_file, result2_file, output_file_path):
101
  input_file_id=batch_input_file_id,
102
  endpoint="/v1/chat/completions",
103
  completion_window="24h",
104
- metadata={
105
- "description": "nightly eval job"
106
- }
107
  )
108
 
109
- print(f'Batch {batch.id} has been created.')
 
110
 
111
  if __name__ == "__main__":
112
- batch_eval()
 
 
1
  import re
2
  import json
3
  import jsonlines
 
8
  def batch_eval(query_file, result1_file, result2_file, output_file_path):
9
  client = OpenAI()
10
 
11
+ with open(query_file, "r") as f:
12
  data = f.read()
13
 
14
+ queries = re.findall(r"- Question \d+: (.+)", data)
15
 
16
+ with open(result1_file, "r") as f:
17
  answers1 = json.load(f)
18
+ answers1 = [i["result"] for i in answers1]
19
 
20
+ with open(result2_file, "r") as f:
21
  answers2 = json.load(f)
22
+ answers2 = [i["result"] for i in answers2]
23
 
24
  requests = []
25
  for i, (query, answer1, answer2) in enumerate(zip(queries, answers1, answers2)):
26
+ sys_prompt = """
27
  ---Role---
28
  You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
29
  """
30
 
31
  prompt = f"""
32
+ You will evaluate two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
33
 
34
  - **Comprehensiveness**: How much detail does the answer provide to cover all aspects and details of the question?
35
  - **Diversity**: How varied and rich is the answer in providing different perspectives and insights on the question?
 
68
  }}
69
  """
70
 
 
71
  request_data = {
72
  "custom_id": f"request-{i+1}",
73
  "method": "POST",
 
76
  "model": "gpt-4o-mini",
77
  "messages": [
78
  {"role": "system", "content": sys_prompt},
79
+ {"role": "user", "content": prompt},
80
  ],
81
+ },
82
  }
83
+
84
  requests.append(request_data)
85
 
86
+ with jsonlines.open(output_file_path, mode="w") as writer:
87
  for request in requests:
88
  writer.write(request)
89
 
90
  print(f"Batch API requests written to {output_file_path}")
91
 
92
  batch_input_file = client.files.create(
93
+ file=open(output_file_path, "rb"), purpose="batch"
 
94
  )
95
  batch_input_file_id = batch_input_file.id
96
 
 
98
  input_file_id=batch_input_file_id,
99
  endpoint="/v1/chat/completions",
100
  completion_window="24h",
101
+ metadata={"description": "nightly eval job"},
 
 
102
  )
103
 
104
+ print(f"Batch {batch.id} has been created.")
105
+
106
 
107
  if __name__ == "__main__":
108
+ batch_eval()
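For reference, a minimal sketch of how the reformatted `batch_eval` might be invoked after this change. The file names below are hypothetical placeholders (the script itself calls `batch_eval()` with no arguments), and an `OPENAI_API_KEY` must be available for the `OpenAI()` client:

```python
# Hypothetical invocation of examples/batch_eval.py; all paths are placeholders.
# queries.txt          : "- Question N: ..." lines, e.g. from examples/generate_query.py
# result1.json         : JSON list of {"result": ...} answers from the first system
# result2.json         : JSON list of {"result": ...} answers from the second system
# batch_requests.jsonl : OpenAI Batch API request file written by batch_eval
from batch_eval import batch_eval

batch_eval(
    query_file="queries.txt",
    result1_file="result1.json",
    result2_file="result2.json",
    output_file_path="batch_requests.jsonl",
)
```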
examples/generate_query.py CHANGED
@@ -1,9 +1,8 @@
1
- import os
2
-
3
  from openai import OpenAI
4
 
5
  # os.environ["OPENAI_API_KEY"] = ""
6
 
 
7
  def openai_complete_if_cache(
8
  model="gpt-4o-mini", prompt=None, system_prompt=None, history_messages=[], **kwargs
9
  ) -> str:
@@ -47,10 +46,10 @@ if __name__ == "__main__":
47
  ...
48
  """
49
 
50
- result = openai_complete_if_cache(model='gpt-4o-mini', prompt=prompt)
51
 
52
- file_path = f"./queries.txt"
53
  with open(file_path, "w") as file:
54
  file.write(result)
55
 
56
- print(f"Queries written to {file_path}")
 
 
 
1
  from openai import OpenAI
2
 
3
  # os.environ["OPENAI_API_KEY"] = ""
4
 
5
+
6
  def openai_complete_if_cache(
7
  model="gpt-4o-mini", prompt=None, system_prompt=None, history_messages=[], **kwargs
8
  ) -> str:
 
46
  ...
47
  """
48
 
49
+ result = openai_complete_if_cache(model="gpt-4o-mini", prompt=prompt)
50
 
51
+ file_path = "./queries.txt"
52
  with open(file_path, "w") as file:
53
  file.write(result)
54
 
55
+ print(f"Queries written to {file_path}")
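The queries written to `./queries.txt` are consumed later in the evaluation pipeline; a small sketch of reading them back with the same regular expression this commit uses in `batch_eval.py` (the file name simply mirrors the demo above):

```python
import re

# Recover the numbered questions produced by examples/generate_query.py.
with open("./queries.txt", "r") as f:
    data = f.read()

queries = re.findall(r"- Question \d+: (.+)", data)
print(f"Recovered {len(queries)} queries")
```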
examples/lightrag_azure_openai_demo.py CHANGED
@@ -122,4 +122,4 @@ print("\nResult (Global):")
122
  print(rag.query(query_text, param=QueryParam(mode="global")))
123
 
124
  print("\nResult (Hybrid):")
125
- print(rag.query(query_text, param=QueryParam(mode="hybrid")))
 
122
  print(rag.query(query_text, param=QueryParam(mode="global")))
123
 
124
  print("\nResult (Hybrid):")
125
+ print(rag.query(query_text, param=QueryParam(mode="hybrid")))
examples/lightrag_bedrock_demo.py CHANGED
@@ -20,13 +20,11 @@ rag = LightRAG(
20
  llm_model_func=bedrock_complete,
21
  llm_model_name="Anthropic Claude 3 Haiku // Amazon Bedrock",
22
  embedding_func=EmbeddingFunc(
23
- embedding_dim=1024,
24
- max_token_size=8192,
25
- func=bedrock_embedding
26
- )
27
  )
28
 
29
- with open("./book.txt", 'r', encoding='utf-8') as f:
30
  rag.insert(f.read())
31
 
32
  for mode in ["naive", "local", "global", "hybrid"]:
@@ -34,8 +32,5 @@ for mode in ["naive", "local", "global", "hybrid"]:
34
  print(f"| {mode.capitalize()} |")
35
  print("+-" + "-" * len(mode) + "-+\n")
36
  print(
37
- rag.query(
38
- "What are the top themes in this story?",
39
- param=QueryParam(mode=mode)
40
- )
41
  )
 
20
  llm_model_func=bedrock_complete,
21
  llm_model_name="Anthropic Claude 3 Haiku // Amazon Bedrock",
22
  embedding_func=EmbeddingFunc(
23
+ embedding_dim=1024, max_token_size=8192, func=bedrock_embedding
24
+ ),
 
 
25
  )
26
 
27
+ with open("./book.txt", "r", encoding="utf-8") as f:
28
  rag.insert(f.read())
29
 
30
  for mode in ["naive", "local", "global", "hybrid"]:
 
32
  print(f"| {mode.capitalize()} |")
33
  print("+-" + "-" * len(mode) + "-+\n")
34
  print(
35
+ rag.query("What are the top themes in this story?", param=QueryParam(mode=mode))
 
 
 
36
  )
examples/lightrag_hf_demo.py CHANGED
@@ -1,10 +1,9 @@
1
  import os
2
- import sys
3
 
4
  from lightrag import LightRAG, QueryParam
5
  from lightrag.llm import hf_model_complete, hf_embedding
6
  from lightrag.utils import EmbeddingFunc
7
- from transformers import AutoModel,AutoTokenizer
8
 
9
  WORKING_DIR = "./dickens"
10
 
@@ -13,16 +12,20 @@ if not os.path.exists(WORKING_DIR):
13
 
14
  rag = LightRAG(
15
  working_dir=WORKING_DIR,
16
- llm_model_func=hf_model_complete,
17
- llm_model_name='meta-llama/Llama-3.1-8B-Instruct',
18
  embedding_func=EmbeddingFunc(
19
  embedding_dim=384,
20
  max_token_size=5000,
21
  func=lambda texts: hf_embedding(
22
- texts,
23
- tokenizer=AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2"),
24
- embed_model=AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
25
- )
 
 
 
 
26
  ),
27
  )
28
 
@@ -31,13 +34,21 @@ with open("./book.txt") as f:
31
  rag.insert(f.read())
32
 
33
  # Perform naive search
34
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
 
 
35
 
36
  # Perform local search
37
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
 
 
38
 
39
  # Perform global search
40
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
 
 
41
 
42
  # Perform hybrid search
43
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
 
 
 
1
  import os
 
2
 
3
  from lightrag import LightRAG, QueryParam
4
  from lightrag.llm import hf_model_complete, hf_embedding
5
  from lightrag.utils import EmbeddingFunc
6
+ from transformers import AutoModel, AutoTokenizer
7
 
8
  WORKING_DIR = "./dickens"
9
 
 
12
 
13
  rag = LightRAG(
14
  working_dir=WORKING_DIR,
15
+ llm_model_func=hf_model_complete,
16
+ llm_model_name="meta-llama/Llama-3.1-8B-Instruct",
17
  embedding_func=EmbeddingFunc(
18
  embedding_dim=384,
19
  max_token_size=5000,
20
  func=lambda texts: hf_embedding(
21
+ texts,
22
+ tokenizer=AutoTokenizer.from_pretrained(
23
+ "sentence-transformers/all-MiniLM-L6-v2"
24
+ ),
25
+ embed_model=AutoModel.from_pretrained(
26
+ "sentence-transformers/all-MiniLM-L6-v2"
27
+ ),
28
+ ),
29
  ),
30
  )
31
 
 
34
  rag.insert(f.read())
35
 
36
  # Perform naive search
37
+ print(
38
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
39
+ )
40
 
41
  # Perform local search
42
+ print(
43
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
44
+ )
45
 
46
  # Perform global search
47
+ print(
48
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
49
+ )
50
 
51
  # Perform hybrid search
52
+ print(
53
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
54
+ )
examples/lightrag_ollama_demo.py CHANGED
@@ -11,15 +11,12 @@ if not os.path.exists(WORKING_DIR):
11
 
12
  rag = LightRAG(
13
  working_dir=WORKING_DIR,
14
- llm_model_func=ollama_model_complete,
15
- llm_model_name='your_model_name',
16
  embedding_func=EmbeddingFunc(
17
  embedding_dim=768,
18
  max_token_size=8192,
19
- func=lambda texts: ollama_embedding(
20
- texts,
21
- embed_model="nomic-embed-text"
22
- )
23
  ),
24
  )
25
 
@@ -28,13 +25,21 @@ with open("./book.txt") as f:
28
  rag.insert(f.read())
29
 
30
  # Perform naive search
31
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
 
 
32
 
33
  # Perform local search
34
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
 
 
35
 
36
  # Perform global search
37
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
 
 
38
 
39
  # Perform hybrid search
40
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
 
 
 
11
 
12
  rag = LightRAG(
13
  working_dir=WORKING_DIR,
14
+ llm_model_func=ollama_model_complete,
15
+ llm_model_name="your_model_name",
16
  embedding_func=EmbeddingFunc(
17
  embedding_dim=768,
18
  max_token_size=8192,
19
+ func=lambda texts: ollama_embedding(texts, embed_model="nomic-embed-text"),
 
 
 
20
  ),
21
  )
22
 
 
25
  rag.insert(f.read())
26
 
27
  # Perform naive search
28
+ print(
29
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
30
+ )
31
 
32
  # Perform local search
33
+ print(
34
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
35
+ )
36
 
37
  # Perform global search
38
+ print(
39
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
40
+ )
41
 
42
  # Perform hybrid search
43
+ print(
44
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
45
+ )
examples/lightrag_openai_compatible_demo.py CHANGED
@@ -6,10 +6,11 @@ from lightrag.utils import EmbeddingFunc
6
  import numpy as np
7
 
8
  WORKING_DIR = "./dickens"
9
-
10
  if not os.path.exists(WORKING_DIR):
11
  os.mkdir(WORKING_DIR)
12
 
 
13
  async def llm_model_func(
14
  prompt, system_prompt=None, history_messages=[], **kwargs
15
  ) -> str:
@@ -20,17 +21,19 @@ async def llm_model_func(
20
  history_messages=history_messages,
21
  api_key=os.getenv("UPSTAGE_API_KEY"),
22
  base_url="https://api.upstage.ai/v1/solar",
23
- **kwargs
24
  )
25
 
 
26
  async def embedding_func(texts: list[str]) -> np.ndarray:
27
  return await openai_embedding(
28
  texts,
29
  model="solar-embedding-1-large-query",
30
  api_key=os.getenv("UPSTAGE_API_KEY"),
31
- base_url="https://api.upstage.ai/v1/solar"
32
  )
33
 
 
34
  # function test
35
  async def test_funcs():
36
  result = await llm_model_func("How are you?")
@@ -39,6 +42,7 @@ async def test_funcs():
39
  result = await embedding_func(["How are you?"])
40
  print("embedding_func: ", result)
41
 
 
42
  asyncio.run(test_funcs())
43
 
44
 
@@ -46,10 +50,8 @@ rag = LightRAG(
46
  working_dir=WORKING_DIR,
47
  llm_model_func=llm_model_func,
48
  embedding_func=EmbeddingFunc(
49
- embedding_dim=4096,
50
- max_token_size=8192,
51
- func=embedding_func
52
- )
53
  )
54
 
55
 
@@ -57,13 +59,21 @@ with open("./book.txt") as f:
57
  rag.insert(f.read())
58
 
59
  # Perform naive search
60
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
 
 
61
 
62
  # Perform local search
63
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
 
 
64
 
65
  # Perform global search
66
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
 
 
67
 
68
  # Perform hybrid search
69
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
 
 
 
6
  import numpy as np
7
 
8
  WORKING_DIR = "./dickens"
9
+
10
  if not os.path.exists(WORKING_DIR):
11
  os.mkdir(WORKING_DIR)
12
 
13
+
14
  async def llm_model_func(
15
  prompt, system_prompt=None, history_messages=[], **kwargs
16
  ) -> str:
 
21
  history_messages=history_messages,
22
  api_key=os.getenv("UPSTAGE_API_KEY"),
23
  base_url="https://api.upstage.ai/v1/solar",
24
+ **kwargs,
25
  )
26
 
27
+
28
  async def embedding_func(texts: list[str]) -> np.ndarray:
29
  return await openai_embedding(
30
  texts,
31
  model="solar-embedding-1-large-query",
32
  api_key=os.getenv("UPSTAGE_API_KEY"),
33
+ base_url="https://api.upstage.ai/v1/solar",
34
  )
35
 
36
+
37
  # function test
38
  async def test_funcs():
39
  result = await llm_model_func("How are you?")
 
42
  result = await embedding_func(["How are you?"])
43
  print("embedding_func: ", result)
44
 
45
+
46
  asyncio.run(test_funcs())
47
 
48
 
 
50
  working_dir=WORKING_DIR,
51
  llm_model_func=llm_model_func,
52
  embedding_func=EmbeddingFunc(
53
+ embedding_dim=4096, max_token_size=8192, func=embedding_func
54
+ ),
 
 
55
  )
56
 
57
 
 
59
  rag.insert(f.read())
60
 
61
  # Perform naive search
62
+ print(
63
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
64
+ )
65
 
66
  # Perform local search
67
+ print(
68
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
69
+ )
70
 
71
  # Perform global search
72
+ print(
73
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
74
+ )
75
 
76
  # Perform hybrid search
77
+ print(
78
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
79
+ )
examples/lightrag_openai_demo.py CHANGED
@@ -1,9 +1,7 @@
1
  import os
2
- import sys
3
 
4
  from lightrag import LightRAG, QueryParam
5
- from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
6
- from transformers import AutoModel,AutoTokenizer
7
 
8
  WORKING_DIR = "./dickens"
9
 
@@ -12,7 +10,7 @@ if not os.path.exists(WORKING_DIR):
12
 
13
  rag = LightRAG(
14
  working_dir=WORKING_DIR,
15
- llm_model_func=gpt_4o_mini_complete
16
  # llm_model_func=gpt_4o_complete
17
  )
18
 
@@ -21,13 +19,21 @@ with open("./book.txt") as f:
21
  rag.insert(f.read())
22
 
23
  # Perform naive search
24
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
 
 
25
 
26
  # Perform local search
27
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
 
 
28
 
29
  # Perform global search
30
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
 
 
31
 
32
  # Perform hybrid search
33
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
 
 
 
1
  import os
 
2
 
3
  from lightrag import LightRAG, QueryParam
4
+ from lightrag.llm import gpt_4o_mini_complete
 
5
 
6
  WORKING_DIR = "./dickens"
7
 
 
10
 
11
  rag = LightRAG(
12
  working_dir=WORKING_DIR,
13
+ llm_model_func=gpt_4o_mini_complete,
14
  # llm_model_func=gpt_4o_complete
15
  )
16
 
 
19
  rag.insert(f.read())
20
 
21
  # Perform naive search
22
+ print(
23
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
24
+ )
25
 
26
  # Perform local search
27
+ print(
28
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
29
+ )
30
 
31
  # Perform global search
32
+ print(
33
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
34
+ )
35
 
36
  # Perform hybrid search
37
+ print(
38
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
39
+ )
lightrag/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- from .lightrag import LightRAG, QueryParam
2
 
3
  __version__ = "0.0.6"
4
  __author__ = "Zirui Guo"
 
1
+ from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
2
 
3
  __version__ = "0.0.6"
4
  __author__ = "Zirui Guo"
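The `LightRAG as LightRAG` spelling is the redundant-alias idiom that Ruff's unused-import rule (F401) treats as an explicit re-export from a package `__init__.py`, which is why the plain import was rewritten here. An equivalent alternative (not what this commit uses, shown only for comparison) is to keep the plain import and declare `__all__`:

```python
# Alternative re-export style: plain imports plus an explicit __all__.
from .lightrag import LightRAG, QueryParam

__all__ = ["LightRAG", "QueryParam"]
```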
lightrag/base.py CHANGED
@@ -12,15 +12,16 @@ TextChunkSchema = TypedDict(
12
 
13
  T = TypeVar("T")
14
 
 
15
  @dataclass
16
  class QueryParam:
17
  mode: Literal["local", "global", "hybrid", "naive"] = "global"
18
  only_need_context: bool = False
19
  response_type: str = "Multiple Paragraphs"
20
  top_k: int = 60
21
- max_token_for_text_unit: int = 4000
22
  max_token_for_global_context: int = 4000
23
- max_token_for_local_context: int = 4000
24
 
25
 
26
  @dataclass
@@ -36,6 +37,7 @@ class StorageNameSpace:
36
  """commit the storage operations after querying"""
37
  pass
38
 
 
39
  @dataclass
40
  class BaseVectorStorage(StorageNameSpace):
41
  embedding_func: EmbeddingFunc
@@ -50,6 +52,7 @@ class BaseVectorStorage(StorageNameSpace):
50
  """
51
  raise NotImplementedError
52
 
 
53
  @dataclass
54
  class BaseKVStorage(Generic[T], StorageNameSpace):
55
  async def all_keys(self) -> list[str]:
@@ -72,7 +75,7 @@ class BaseKVStorage(Generic[T], StorageNameSpace):
72
 
73
  async def drop(self):
74
  raise NotImplementedError
75
-
76
 
77
  @dataclass
78
  class BaseGraphStorage(StorageNameSpace):
@@ -113,4 +116,4 @@ class BaseGraphStorage(StorageNameSpace):
113
  raise NotImplementedError
114
 
115
  async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
116
- raise NotImplementedError("Node embedding is not used in lightrag.")
 
12
 
13
  T = TypeVar("T")
14
 
15
+
16
  @dataclass
17
  class QueryParam:
18
  mode: Literal["local", "global", "hybrid", "naive"] = "global"
19
  only_need_context: bool = False
20
  response_type: str = "Multiple Paragraphs"
21
  top_k: int = 60
22
+ max_token_for_text_unit: int = 4000
23
  max_token_for_global_context: int = 4000
24
+ max_token_for_local_context: int = 4000
25
 
26
 
27
  @dataclass
 
37
  """commit the storage operations after querying"""
38
  pass
39
 
40
+
41
  @dataclass
42
  class BaseVectorStorage(StorageNameSpace):
43
  embedding_func: EmbeddingFunc
 
52
  """
53
  raise NotImplementedError
54
 
55
+
56
  @dataclass
57
  class BaseKVStorage(Generic[T], StorageNameSpace):
58
  async def all_keys(self) -> list[str]:
 
75
 
76
  async def drop(self):
77
  raise NotImplementedError
78
+
79
 
80
  @dataclass
81
  class BaseGraphStorage(StorageNameSpace):
 
116
  raise NotImplementedError
117
 
118
  async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
119
+ raise NotImplementedError("Node embedding is not used in lightrag.")
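For orientation, the `QueryParam` dataclass adjusted above can be constructed directly with the fields visible in this diff; a minimal sketch with illustrative (not recommended) values:

```python
from lightrag import QueryParam

# Hybrid retrieval, return only the assembled context, smaller token budgets.
param = QueryParam(
    mode="hybrid",               # one of "local", "global", "hybrid", "naive"
    only_need_context=True,
    top_k=30,
    max_token_for_text_unit=2000,
    max_token_for_local_context=2000,
)
```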
lightrag/lightrag.py CHANGED
@@ -3,10 +3,12 @@ import os
3
  from dataclasses import asdict, dataclass, field
4
  from datetime import datetime
5
  from functools import partial
6
- from typing import Type, cast, Any
7
- from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
8
 
9
- from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding, hf_model_complete, hf_embedding
 
 
 
10
  from .operate import (
11
  chunking_by_token_size,
12
  extract_entities,
@@ -37,6 +39,7 @@ from .base import (
37
  QueryParam,
38
  )
39
 
 
40
  def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
41
  try:
42
  loop = asyncio.get_running_loop()
@@ -69,7 +72,6 @@ class LightRAG:
69
  "dimensions": 1536,
70
  "num_walks": 10,
71
  "walk_length": 40,
72
- "num_walks": 10,
73
  "window_size": 2,
74
  "iterations": 3,
75
  "random_seed": 3,
@@ -77,13 +79,13 @@ class LightRAG:
77
  )
78
 
79
  # embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
80
- embedding_func: EmbeddingFunc = field(default_factory=lambda:openai_embedding)
81
  embedding_batch_num: int = 32
82
  embedding_func_max_async: int = 16
83
 
84
  # LLM
85
- llm_model_func: callable = gpt_4o_mini_complete#hf_model_complete#
86
- llm_model_name: str = 'meta-llama/Llama-3.2-1B-Instruct'#'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
87
  llm_model_max_token_size: int = 32768
88
  llm_model_max_async: int = 16
89
 
@@ -98,11 +100,11 @@ class LightRAG:
98
  addon_params: dict = field(default_factory=dict)
99
  convert_response_to_json_func: callable = convert_response_to_json
100
 
101
- def __post_init__(self):
102
  log_file = os.path.join(self.working_dir, "lightrag.log")
103
  set_logger(log_file)
104
  logger.info(f"Logger initialized for working directory: {self.working_dir}")
105
-
106
  _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
107
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
108
 
@@ -133,30 +135,24 @@ class LightRAG:
133
  self.embedding_func
134
  )
135
 
136
- self.entities_vdb = (
137
- self.vector_db_storage_cls(
138
- namespace="entities",
139
- global_config=asdict(self),
140
- embedding_func=self.embedding_func,
141
- meta_fields={"entity_name"}
142
- )
143
  )
144
- self.relationships_vdb = (
145
- self.vector_db_storage_cls(
146
- namespace="relationships",
147
- global_config=asdict(self),
148
- embedding_func=self.embedding_func,
149
- meta_fields={"src_id", "tgt_id"}
150
- )
151
  )
152
- self.chunks_vdb = (
153
- self.vector_db_storage_cls(
154
- namespace="chunks",
155
- global_config=asdict(self),
156
- embedding_func=self.embedding_func,
157
- )
158
  )
159
-
160
  self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
161
  partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
162
  )
@@ -177,7 +173,7 @@ class LightRAG:
177
  _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
178
  new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
179
  if not len(new_docs):
180
- logger.warning(f"All docs are already in the storage")
181
  return
182
  logger.info(f"[New Docs] inserting {len(new_docs)} docs")
183
 
@@ -203,7 +199,7 @@ class LightRAG:
203
  k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
204
  }
205
  if not len(inserting_chunks):
206
- logger.warning(f"All chunks are already in the storage")
207
  return
208
  logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
209
 
@@ -246,7 +242,7 @@ class LightRAG:
246
  def query(self, query: str, param: QueryParam = QueryParam()):
247
  loop = always_get_an_event_loop()
248
  return loop.run_until_complete(self.aquery(query, param))
249
-
250
  async def aquery(self, query: str, param: QueryParam = QueryParam()):
251
  if param.mode == "local":
252
  response = await local_query(
@@ -290,7 +286,6 @@ class LightRAG:
290
  raise ValueError(f"Unknown mode {param.mode}")
291
  await self._query_done()
292
  return response
293
-
294
 
295
  async def _query_done(self):
296
  tasks = []
@@ -299,5 +294,3 @@ class LightRAG:
299
  continue
300
  tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
301
  await asyncio.gather(*tasks)
302
-
303
-
 
3
  from dataclasses import asdict, dataclass, field
4
  from datetime import datetime
5
  from functools import partial
6
+ from typing import Type, cast
 
7
 
8
+ from .llm import (
9
+ gpt_4o_mini_complete,
10
+ openai_embedding,
11
+ )
12
  from .operate import (
13
  chunking_by_token_size,
14
  extract_entities,
 
39
  QueryParam,
40
  )
41
 
42
+
43
  def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
44
  try:
45
  loop = asyncio.get_running_loop()
 
72
  "dimensions": 1536,
73
  "num_walks": 10,
74
  "walk_length": 40,
 
75
  "window_size": 2,
76
  "iterations": 3,
77
  "random_seed": 3,
 
79
  )
80
 
81
  # embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
82
+ embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
83
  embedding_batch_num: int = 32
84
  embedding_func_max_async: int = 16
85
 
86
  # LLM
87
+ llm_model_func: callable = gpt_4o_mini_complete # hf_model_complete#
88
+ llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
89
  llm_model_max_token_size: int = 32768
90
  llm_model_max_async: int = 16
91
 
 
100
  addon_params: dict = field(default_factory=dict)
101
  convert_response_to_json_func: callable = convert_response_to_json
102
 
103
+ def __post_init__(self):
104
  log_file = os.path.join(self.working_dir, "lightrag.log")
105
  set_logger(log_file)
106
  logger.info(f"Logger initialized for working directory: {self.working_dir}")
107
+
108
  _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
109
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
110
 
 
135
  self.embedding_func
136
  )
137
 
138
+ self.entities_vdb = self.vector_db_storage_cls(
139
+ namespace="entities",
140
+ global_config=asdict(self),
141
+ embedding_func=self.embedding_func,
142
+ meta_fields={"entity_name"},
 
 
143
  )
144
+ self.relationships_vdb = self.vector_db_storage_cls(
145
+ namespace="relationships",
146
+ global_config=asdict(self),
147
+ embedding_func=self.embedding_func,
148
+ meta_fields={"src_id", "tgt_id"},
 
 
149
  )
150
+ self.chunks_vdb = self.vector_db_storage_cls(
151
+ namespace="chunks",
152
+ global_config=asdict(self),
153
+ embedding_func=self.embedding_func,
 
 
154
  )
155
+
156
  self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
157
  partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
158
  )
 
173
  _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
174
  new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
175
  if not len(new_docs):
176
+ logger.warning("All docs are already in the storage")
177
  return
178
  logger.info(f"[New Docs] inserting {len(new_docs)} docs")
179
 
 
199
  k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
200
  }
201
  if not len(inserting_chunks):
202
+ logger.warning("All chunks are already in the storage")
203
  return
204
  logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
205
 
 
242
  def query(self, query: str, param: QueryParam = QueryParam()):
243
  loop = always_get_an_event_loop()
244
  return loop.run_until_complete(self.aquery(query, param))
245
+
246
  async def aquery(self, query: str, param: QueryParam = QueryParam()):
247
  if param.mode == "local":
248
  response = await local_query(
 
286
  raise ValueError(f"Unknown mode {param.mode}")
287
  await self._query_done()
288
  return response
 
289
 
290
  async def _query_done(self):
291
  tasks = []
 
294
  continue
295
  tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
296
  await asyncio.gather(*tasks)
 
 
lightrag/llm.py CHANGED
@@ -1,9 +1,7 @@
1
  import os
2
  import copy
3
  import json
4
- import botocore
5
  import aioboto3
6
- import botocore.errorfactory
7
  import numpy as np
8
  import ollama
9
  from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
@@ -13,24 +11,34 @@ from tenacity import (
13
  wait_exponential,
14
  retry_if_exception_type,
15
  )
16
- from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
17
  import torch
18
  from .base import BaseKVStorage
19
  from .utils import compute_args_hash, wrap_embedding_func_with_attrs
20
- import copy
21
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
 
22
  @retry(
23
  stop=stop_after_attempt(3),
24
  wait=wait_exponential(multiplier=1, min=4, max=10),
25
  retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
26
  )
27
  async def openai_complete_if_cache(
28
- model, prompt, system_prompt=None, history_messages=[], base_url=None, api_key=None, **kwargs
 
 
 
 
 
 
29
  ) -> str:
30
  if api_key:
31
  os.environ["OPENAI_API_KEY"] = api_key
32
 
33
- openai_async_client = AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
 
 
34
  hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
35
  messages = []
36
  if system_prompt:
@@ -64,43 +72,56 @@ class BedrockError(Exception):
64
  retry=retry_if_exception_type((BedrockError)),
65
  )
66
  async def bedrock_complete_if_cache(
67
- model, prompt, system_prompt=None, history_messages=[],
68
- aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None, **kwargs
 
 
 
 
 
 
69
  ) -> str:
70
- os.environ['AWS_ACCESS_KEY_ID'] = os.environ.get('AWS_ACCESS_KEY_ID', aws_access_key_id)
71
- os.environ['AWS_SECRET_ACCESS_KEY'] = os.environ.get('AWS_SECRET_ACCESS_KEY', aws_secret_access_key)
72
- os.environ['AWS_SESSION_TOKEN'] = os.environ.get('AWS_SESSION_TOKEN', aws_session_token)
 
 
 
 
 
 
73
 
74
  # Fix message history format
75
  messages = []
76
  for history_message in history_messages:
77
  message = copy.copy(history_message)
78
- message['content'] = [{'text': message['content']}]
79
  messages.append(message)
80
 
81
  # Add user prompt
82
- messages.append({'role': "user", 'content': [{'text': prompt}]})
83
 
84
  # Initialize Converse API arguments
85
- args = {
86
- 'modelId': model,
87
- 'messages': messages
88
- }
89
 
90
  # Define system prompt
91
  if system_prompt:
92
- args['system'] = [{'text': system_prompt}]
93
 
94
  # Map and set up inference parameters
95
  inference_params_map = {
96
- 'max_tokens': "maxTokens",
97
- 'top_p': "topP",
98
- 'stop_sequences': "stopSequences"
99
  }
100
- if (inference_params := list(set(kwargs) & set(['max_tokens', 'temperature', 'top_p', 'stop_sequences']))):
101
- args['inferenceConfig'] = {}
 
 
102
  for param in inference_params:
103
- args['inferenceConfig'][inference_params_map.get(param, param)] = kwargs.pop(param)
 
 
104
 
105
  hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
106
  if hashing_kv is not None:
@@ -112,31 +133,33 @@ async def bedrock_complete_if_cache(
112
  # Call model via Converse API
113
  session = aioboto3.Session()
114
  async with session.client("bedrock-runtime") as bedrock_async_client:
115
-
116
  try:
117
  response = await bedrock_async_client.converse(**args, **kwargs)
118
  except Exception as e:
119
  raise BedrockError(e)
120
 
121
  if hashing_kv is not None:
122
- await hashing_kv.upsert({
123
- args_hash: {
124
- 'return': response['output']['message']['content'][0]['text'],
125
- 'model': model
 
 
126
  }
127
- })
 
 
128
 
129
- return response['output']['message']['content'][0]['text']
130
 
131
  async def hf_model_if_cache(
132
  model, prompt, system_prompt=None, history_messages=[], **kwargs
133
  ) -> str:
134
  model_name = model
135
- hf_tokenizer = AutoTokenizer.from_pretrained(model_name,device_map = 'auto')
136
- if hf_tokenizer.pad_token == None:
137
  # print("use eos token")
138
  hf_tokenizer.pad_token = hf_tokenizer.eos_token
139
- hf_model = AutoModelForCausalLM.from_pretrained(model_name,device_map = 'auto')
140
  hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
141
  messages = []
142
  if system_prompt:
@@ -149,30 +172,51 @@ async def hf_model_if_cache(
149
  if_cache_return = await hashing_kv.get_by_id(args_hash)
150
  if if_cache_return is not None:
151
  return if_cache_return["return"]
152
- input_prompt = ''
153
  try:
154
- input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
155
- except:
 
 
156
  try:
157
  ori_message = copy.deepcopy(messages)
158
- if messages[0]['role'] == "system":
159
- messages[1]['content'] = "<system>" + messages[0]['content'] + "</system>\n" + messages[1]['content']
 
 
 
 
 
160
  messages = messages[1:]
161
- input_prompt = hf_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
162
- except:
 
 
163
  len_message = len(ori_message)
164
  for msgid in range(len_message):
165
- input_prompt =input_prompt+ '<'+ori_message[msgid]['role']+'>'+ori_message[msgid]['content']+'</'+ori_message[msgid]['role']+'>\n'
166
-
167
- input_ids = hf_tokenizer(input_prompt, return_tensors='pt', padding=True, truncation=True).to("cuda")
168
- output = hf_model.generate(**input_ids, max_new_tokens=200, num_return_sequences=1,early_stopping = True)
 
 
 
 
 
 
 
 
 
 
 
 
 
169
  response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
170
  if hashing_kv is not None:
171
- await hashing_kv.upsert(
172
- {args_hash: {"return": response_text, "model": model}}
173
- )
174
  return response_text
175
 
 
176
  async def ollama_model_if_cache(
177
  model, prompt, system_prompt=None, history_messages=[], **kwargs
178
  ) -> str:
@@ -202,6 +246,7 @@ async def ollama_model_if_cache(
202
 
203
  return result
204
 
 
205
  async def gpt_4o_complete(
206
  prompt, system_prompt=None, history_messages=[], **kwargs
207
  ) -> str:
@@ -241,7 +286,7 @@ async def bedrock_complete(
241
  async def hf_model_complete(
242
  prompt, system_prompt=None, history_messages=[], **kwargs
243
  ) -> str:
244
- model_name = kwargs['hashing_kv'].global_config['llm_model_name']
245
  return await hf_model_if_cache(
246
  model_name,
247
  prompt,
@@ -250,10 +295,11 @@ async def hf_model_complete(
250
  **kwargs,
251
  )
252
 
 
253
  async def ollama_model_complete(
254
  prompt, system_prompt=None, history_messages=[], **kwargs
255
  ) -> str:
256
- model_name = kwargs['hashing_kv'].global_config['llm_model_name']
257
  return await ollama_model_if_cache(
258
  model_name,
259
  prompt,
@@ -262,17 +308,25 @@ async def ollama_model_complete(
262
  **kwargs,
263
  )
264
 
 
265
  @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
266
  @retry(
267
  stop=stop_after_attempt(3),
268
  wait=wait_exponential(multiplier=1, min=4, max=10),
269
  retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
270
  )
271
- async def openai_embedding(texts: list[str], model: str = "text-embedding-3-small", base_url: str = None, api_key: str = None) -> np.ndarray:
 
 
 
 
 
272
  if api_key:
273
  os.environ["OPENAI_API_KEY"] = api_key
274
 
275
- openai_async_client = AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
 
 
276
  response = await openai_async_client.embeddings.create(
277
  model=model, input=texts, encoding_format="float"
278
  )
@@ -286,28 +340,37 @@ async def openai_embedding(texts: list[str], model: str = "text-embedding-3-smal
286
  # retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
287
  # )
288
  async def bedrock_embedding(
289
- texts: list[str], model: str = "amazon.titan-embed-text-v2:0",
290
- aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None) -> np.ndarray:
291
- os.environ['AWS_ACCESS_KEY_ID'] = os.environ.get('AWS_ACCESS_KEY_ID', aws_access_key_id)
292
- os.environ['AWS_SECRET_ACCESS_KEY'] = os.environ.get('AWS_SECRET_ACCESS_KEY', aws_secret_access_key)
293
- os.environ['AWS_SESSION_TOKEN'] = os.environ.get('AWS_SESSION_TOKEN', aws_session_token)
 
 
 
 
 
 
 
 
 
 
294
 
295
  session = aioboto3.Session()
296
  async with session.client("bedrock-runtime") as bedrock_async_client:
297
-
298
  if (model_provider := model.split(".")[0]) == "amazon":
299
  embed_texts = []
300
  for text in texts:
301
  if "v2" in model:
302
- body = json.dumps({
303
- 'inputText': text,
304
- # 'dimensions': embedding_dim,
305
- 'embeddingTypes': ["float"]
306
- })
 
 
307
  elif "v1" in model:
308
- body = json.dumps({
309
- 'inputText': text
310
- })
311
  else:
312
  raise ValueError(f"Model {model} is not supported!")
313
 
@@ -315,29 +378,27 @@ async def bedrock_embedding(
315
  modelId=model,
316
  body=body,
317
  accept="application/json",
318
- contentType="application/json"
319
  )
320
 
321
- response_body = await response.get('body').json()
322
 
323
- embed_texts.append(response_body['embedding'])
324
  elif model_provider == "cohere":
325
- body = json.dumps({
326
- 'texts': texts,
327
- 'input_type': "search_document",
328
- 'truncate': "NONE"
329
- })
330
 
331
  response = await bedrock_async_client.invoke_model(
332
  model=model,
333
  body=body,
334
  accept="application/json",
335
- contentType="application/json"
336
  )
337
 
338
- response_body = json.loads(response.get('body').read())
339
 
340
- embed_texts = response_body['embeddings']
341
  else:
342
  raise ValueError(f"Model provider '{model_provider}' is not supported!")
343
 
@@ -345,12 +406,15 @@ async def bedrock_embedding(
345
 
346
 
347
  async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
348
- input_ids = tokenizer(texts, return_tensors='pt', padding=True, truncation=True).input_ids
 
 
349
  with torch.no_grad():
350
  outputs = embed_model(input_ids)
351
  embeddings = outputs.last_hidden_state.mean(dim=1)
352
  return embeddings.detach().numpy()
353
 
 
354
  async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
355
  embed_text = []
356
  for text in texts:
@@ -359,11 +423,12 @@ async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
359
 
360
  return embed_text
361
 
 
362
  if __name__ == "__main__":
363
  import asyncio
364
 
365
  async def main():
366
- result = await gpt_4o_mini_complete('How are you?')
367
  print(result)
368
 
369
  asyncio.run(main())
 
1
  import os
2
  import copy
3
  import json
 
4
  import aioboto3
 
5
  import numpy as np
6
  import ollama
7
  from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
 
11
  wait_exponential,
12
  retry_if_exception_type,
13
  )
14
+ from transformers import AutoTokenizer, AutoModelForCausalLM
15
  import torch
16
  from .base import BaseKVStorage
17
  from .utils import compute_args_hash, wrap_embedding_func_with_attrs
18
+
19
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
20
+
21
+
22
  @retry(
23
  stop=stop_after_attempt(3),
24
  wait=wait_exponential(multiplier=1, min=4, max=10),
25
  retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
26
  )
27
  async def openai_complete_if_cache(
28
+ model,
29
+ prompt,
30
+ system_prompt=None,
31
+ history_messages=[],
32
+ base_url=None,
33
+ api_key=None,
34
+ **kwargs,
35
  ) -> str:
36
  if api_key:
37
  os.environ["OPENAI_API_KEY"] = api_key
38
 
39
+ openai_async_client = (
40
+ AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
41
+ )
42
  hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
43
  messages = []
44
  if system_prompt:
 
72
  retry=retry_if_exception_type((BedrockError)),
73
  )
74
  async def bedrock_complete_if_cache(
75
+ model,
76
+ prompt,
77
+ system_prompt=None,
78
+ history_messages=[],
79
+ aws_access_key_id=None,
80
+ aws_secret_access_key=None,
81
+ aws_session_token=None,
82
+ **kwargs,
83
  ) -> str:
84
+ os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
85
+ "AWS_ACCESS_KEY_ID", aws_access_key_id
86
+ )
87
+ os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
88
+ "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
89
+ )
90
+ os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
91
+ "AWS_SESSION_TOKEN", aws_session_token
92
+ )
93
 
94
  # Fix message history format
95
  messages = []
96
  for history_message in history_messages:
97
  message = copy.copy(history_message)
98
+ message["content"] = [{"text": message["content"]}]
99
  messages.append(message)
100
 
101
  # Add user prompt
102
+ messages.append({"role": "user", "content": [{"text": prompt}]})
103
 
104
  # Initialize Converse API arguments
105
+ args = {"modelId": model, "messages": messages}
 
 
 
106
 
107
  # Define system prompt
108
  if system_prompt:
109
+ args["system"] = [{"text": system_prompt}]
110
 
111
  # Map and set up inference parameters
112
  inference_params_map = {
113
+ "max_tokens": "maxTokens",
114
+ "top_p": "topP",
115
+ "stop_sequences": "stopSequences",
116
  }
117
+ if inference_params := list(
118
+ set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
119
+ ):
120
+ args["inferenceConfig"] = {}
121
  for param in inference_params:
122
+ args["inferenceConfig"][inference_params_map.get(param, param)] = (
123
+ kwargs.pop(param)
124
+ )
125
 
126
  hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
127
  if hashing_kv is not None:
 
133
  # Call model via Converse API
134
  session = aioboto3.Session()
135
  async with session.client("bedrock-runtime") as bedrock_async_client:
 
136
  try:
137
  response = await bedrock_async_client.converse(**args, **kwargs)
138
  except Exception as e:
139
  raise BedrockError(e)
140
 
141
  if hashing_kv is not None:
142
+ await hashing_kv.upsert(
143
+ {
144
+ args_hash: {
145
+ "return": response["output"]["message"]["content"][0]["text"],
146
+ "model": model,
147
+ }
148
  }
149
+ )
150
+
151
+ return response["output"]["message"]["content"][0]["text"]
152
 
 
153
 
154
  async def hf_model_if_cache(
155
  model, prompt, system_prompt=None, history_messages=[], **kwargs
156
  ) -> str:
157
  model_name = model
158
+ hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
159
+ if hf_tokenizer.pad_token is None:
160
  # print("use eos token")
161
  hf_tokenizer.pad_token = hf_tokenizer.eos_token
162
+ hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
163
  hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
164
  messages = []
165
  if system_prompt:
 
172
  if_cache_return = await hashing_kv.get_by_id(args_hash)
173
  if if_cache_return is not None:
174
  return if_cache_return["return"]
175
+ input_prompt = ""
176
  try:
177
+ input_prompt = hf_tokenizer.apply_chat_template(
178
+ messages, tokenize=False, add_generation_prompt=True
179
+ )
180
+ except Exception:
181
  try:
182
  ori_message = copy.deepcopy(messages)
183
+ if messages[0]["role"] == "system":
184
+ messages[1]["content"] = (
185
+ "<system>"
186
+ + messages[0]["content"]
187
+ + "</system>\n"
188
+ + messages[1]["content"]
189
+ )
190
  messages = messages[1:]
191
+ input_prompt = hf_tokenizer.apply_chat_template(
192
+ messages, tokenize=False, add_generation_prompt=True
193
+ )
194
+ except Exception:
195
  len_message = len(ori_message)
196
  for msgid in range(len_message):
197
+ input_prompt = (
198
+ input_prompt
199
+ + "<"
200
+ + ori_message[msgid]["role"]
201
+ + ">"
202
+ + ori_message[msgid]["content"]
203
+ + "</"
204
+ + ori_message[msgid]["role"]
205
+ + ">\n"
206
+ )
207
+
208
+ input_ids = hf_tokenizer(
209
+ input_prompt, return_tensors="pt", padding=True, truncation=True
210
+ ).to("cuda")
211
+ output = hf_model.generate(
212
+ **input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True
213
+ )
214
  response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
215
  if hashing_kv is not None:
216
+ await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
 
 
217
  return response_text
218
 
219
+
220
  async def ollama_model_if_cache(
221
  model, prompt, system_prompt=None, history_messages=[], **kwargs
222
  ) -> str:
 
246
 
247
  return result
248
 
249
+
250
  async def gpt_4o_complete(
251
  prompt, system_prompt=None, history_messages=[], **kwargs
252
  ) -> str:
 
286
  async def hf_model_complete(
287
  prompt, system_prompt=None, history_messages=[], **kwargs
288
  ) -> str:
289
+ model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
290
  return await hf_model_if_cache(
291
  model_name,
292
  prompt,
 
295
  **kwargs,
296
  )
297
 
298
+
299
  async def ollama_model_complete(
300
  prompt, system_prompt=None, history_messages=[], **kwargs
301
  ) -> str:
302
+ model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
303
  return await ollama_model_if_cache(
304
  model_name,
305
  prompt,
 
308
  **kwargs,
309
  )
310
 
311
+
312
  @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
313
  @retry(
314
  stop=stop_after_attempt(3),
315
  wait=wait_exponential(multiplier=1, min=4, max=10),
316
  retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
317
  )
318
+ async def openai_embedding(
319
+ texts: list[str],
320
+ model: str = "text-embedding-3-small",
321
+ base_url: str = None,
322
+ api_key: str = None,
323
+ ) -> np.ndarray:
324
  if api_key:
325
  os.environ["OPENAI_API_KEY"] = api_key
326
 
327
+ openai_async_client = (
328
+ AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
329
+ )
330
  response = await openai_async_client.embeddings.create(
331
  model=model, input=texts, encoding_format="float"
332
  )
 
340
  # retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
341
  # )
342
  async def bedrock_embedding(
343
+ texts: list[str],
344
+ model: str = "amazon.titan-embed-text-v2:0",
345
+ aws_access_key_id=None,
346
+ aws_secret_access_key=None,
347
+ aws_session_token=None,
348
+ ) -> np.ndarray:
349
+ os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
350
+ "AWS_ACCESS_KEY_ID", aws_access_key_id
351
+ )
352
+ os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
353
+ "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
354
+ )
355
+ os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
356
+ "AWS_SESSION_TOKEN", aws_session_token
357
+ )
358
 
359
  session = aioboto3.Session()
360
  async with session.client("bedrock-runtime") as bedrock_async_client:
 
361
  if (model_provider := model.split(".")[0]) == "amazon":
362
  embed_texts = []
363
  for text in texts:
364
  if "v2" in model:
365
+ body = json.dumps(
366
+ {
367
+ "inputText": text,
368
+ # 'dimensions': embedding_dim,
369
+ "embeddingTypes": ["float"],
370
+ }
371
+ )
372
  elif "v1" in model:
373
+ body = json.dumps({"inputText": text})
 
 
374
  else:
375
  raise ValueError(f"Model {model} is not supported!")
376
 
 
378
  modelId=model,
379
  body=body,
380
  accept="application/json",
381
+ contentType="application/json",
382
  )
383
 
384
+ response_body = await response.get("body").json()
385
 
386
+ embed_texts.append(response_body["embedding"])
387
  elif model_provider == "cohere":
388
+ body = json.dumps(
389
+ {"texts": texts, "input_type": "search_document", "truncate": "NONE"}
390
+ )
 
 
391
 
392
  response = await bedrock_async_client.invoke_model(
393
  model=model,
394
  body=body,
395
  accept="application/json",
396
+ contentType="application/json",
397
  )
398
 
399
+ response_body = json.loads(response.get("body").read())
400
 
401
+ embed_texts = response_body["embeddings"]
402
  else:
403
  raise ValueError(f"Model provider '{model_provider}' is not supported!")
404
 
 
406
 
407
 
408
  async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
409
+ input_ids = tokenizer(
410
+ texts, return_tensors="pt", padding=True, truncation=True
411
+ ).input_ids
412
  with torch.no_grad():
413
  outputs = embed_model(input_ids)
414
  embeddings = outputs.last_hidden_state.mean(dim=1)
415
  return embeddings.detach().numpy()
416
 
417
+
418
  async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
419
  embed_text = []
420
  for text in texts:
 
423
 
424
  return embed_text
425
 
426
+
427
  if __name__ == "__main__":
428
  import asyncio
429
 
430
  async def main():
431
+ result = await gpt_4o_mini_complete("How are you?")
432
  print(result)
433
 
434
  asyncio.run(main())
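The llm.py changes above add Amazon Bedrock completion and embedding helpers alongside the existing OpenAI, Hugging Face, and Ollama paths. A minimal sketch of calling them directly, assuming AWS credentials are already present in the environment and that the Bedrock model IDs shown are available in your account (the IDs are examples, not part of this commit):

```python
import asyncio
import numpy as np

from lightrag.llm import bedrock_complete_if_cache, bedrock_embedding


async def demo():
    # Completion goes through the Bedrock Converse API; extra kwargs such as
    # max_tokens are mapped onto inferenceConfig by bedrock_complete_if_cache.
    answer = await bedrock_complete_if_cache(
        "anthropic.claude-3-haiku-20240307-v1:0",  # example model ID
        "Summarize what a knowledge graph is in one sentence.",
        system_prompt="You are a concise assistant.",
        max_tokens=128,
    )
    print(answer)

    # Titan v2 embeddings; bedrock_embedding returns an array of vectors.
    vectors = await bedrock_embedding(
        ["graph-based indexing", "dual-level retrieval"],
        model="amazon.titan-embed-text-v2:0",
    )
    print(np.shape(vectors))


asyncio.run(demo())
```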
lightrag/operate.py CHANGED
@@ -25,6 +25,7 @@ from .base import (
25
  )
26
  from .prompt import GRAPH_FIELD_SEP, PROMPTS
27
 
 
28
  def chunking_by_token_size(
29
  content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
30
  ):
@@ -45,6 +46,7 @@ def chunking_by_token_size(
45
  )
46
  return results
47
 
 
48
  async def _handle_entity_relation_summary(
49
  entity_or_relation_name: str,
50
  description: str,
@@ -229,9 +231,10 @@ async def _merge_edges_then_upsert(
229
  description=description,
230
  keywords=keywords,
231
  )
232
-
233
  return edge_data
234
 
 
235
  async def extract_entities(
236
  chunks: dict[str, TextChunkSchema],
237
  knwoledge_graph_inst: BaseGraphStorage,
@@ -352,7 +355,9 @@ async def extract_entities(
352
  logger.warning("Didn't extract any entities, maybe your LLM is not working")
353
  return None
354
  if not len(all_relationships_data):
355
- logger.warning("Didn't extract any relationships, maybe your LLM is not working")
 
 
356
  return None
357
 
358
  if entity_vdb is not None:
@@ -370,7 +375,10 @@ async def extract_entities(
370
  compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
371
  "src_id": dp["src_id"],
372
  "tgt_id": dp["tgt_id"],
373
- "content": dp["keywords"] + dp["src_id"] + dp["tgt_id"] + dp["description"],
 
 
 
374
  }
375
  for dp in all_relationships_data
376
  }
@@ -378,6 +386,7 @@ async def extract_entities(
378
 
379
  return knwoledge_graph_inst
380
 
 
381
  async def local_query(
382
  query,
383
  knowledge_graph_inst: BaseGraphStorage,
@@ -393,19 +402,24 @@ async def local_query(
393
  kw_prompt_temp = PROMPTS["keywords_extraction"]
394
  kw_prompt = kw_prompt_temp.format(query=query)
395
  result = await use_model_func(kw_prompt)
396
-
397
  try:
398
  keywords_data = json.loads(result)
399
  keywords = keywords_data.get("low_level_keywords", [])
400
- keywords = ', '.join(keywords)
401
- except json.JSONDecodeError as e:
402
  try:
403
- result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
404
- result = '{' + result.split('{')[1].split('}')[0] + '}'
 
 
 
 
 
405
 
406
  keywords_data = json.loads(result)
407
  keywords = keywords_data.get("low_level_keywords", [])
408
- keywords = ', '.join(keywords)
409
  # Handle parsing error
410
  except json.JSONDecodeError as e:
411
  print(f"JSON parsing error: {e}")
@@ -430,11 +444,20 @@ async def local_query(
430
  query,
431
  system_prompt=sys_prompt,
432
  )
433
- if len(response)>len(sys_prompt):
434
- response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
435
-
 
 
 
 
 
 
 
 
436
  return response
437
 
 
438
  async def _build_local_query_context(
439
  query,
440
  knowledge_graph_inst: BaseGraphStorage,
@@ -516,6 +539,7 @@ async def _build_local_query_context(
516
  ```
517
  """
518
 
 
519
  async def _find_most_related_text_unit_from_entities(
520
  node_datas: list[dict],
521
  query_param: QueryParam,
@@ -576,6 +600,7 @@ async def _find_most_related_text_unit_from_entities(
576
  all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
577
  return all_text_units
578
 
 
579
  async def _find_most_related_edges_from_entities(
580
  node_datas: list[dict],
581
  query_param: QueryParam,
@@ -609,6 +634,7 @@ async def _find_most_related_edges_from_entities(
609
  )
610
  return all_edges_data
611
 
 
612
  async def global_query(
613
  query,
614
  knowledge_graph_inst: BaseGraphStorage,
@@ -624,20 +650,25 @@ async def global_query(
624
  kw_prompt_temp = PROMPTS["keywords_extraction"]
625
  kw_prompt = kw_prompt_temp.format(query=query)
626
  result = await use_model_func(kw_prompt)
627
-
628
  try:
629
  keywords_data = json.loads(result)
630
  keywords = keywords_data.get("high_level_keywords", [])
631
- keywords = ', '.join(keywords)
632
- except json.JSONDecodeError as e:
633
  try:
634
- result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
635
- result = '{' + result.split('{')[1].split('}')[0] + '}'
 
 
 
 
 
636
 
637
  keywords_data = json.loads(result)
638
  keywords = keywords_data.get("high_level_keywords", [])
639
- keywords = ', '.join(keywords)
640
-
641
  except json.JSONDecodeError as e:
642
  # Handle parsing error
643
  print(f"JSON parsing error: {e}")
@@ -651,12 +682,12 @@ async def global_query(
651
  text_chunks_db,
652
  query_param,
653
  )
654
-
655
  if query_param.only_need_context:
656
  return context
657
  if context is None:
658
  return PROMPTS["fail_response"]
659
-
660
  sys_prompt_temp = PROMPTS["rag_response"]
661
  sys_prompt = sys_prompt_temp.format(
662
  context_data=context, response_type=query_param.response_type
@@ -665,11 +696,20 @@ async def global_query(
665
  query,
666
  system_prompt=sys_prompt,
667
  )
668
- if len(response)>len(sys_prompt):
669
- response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
670
-
 
 
 
 
 
 
 
 
671
  return response
672
 
 
673
  async def _build_global_query_context(
674
  keywords,
675
  knowledge_graph_inst: BaseGraphStorage,
@@ -679,14 +719,14 @@ async def _build_global_query_context(
679
  query_param: QueryParam,
680
  ):
681
  results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
682
-
683
  if not len(results):
684
  return None
685
-
686
  edge_datas = await asyncio.gather(
687
  *[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
688
  )
689
-
690
  if not all([n is not None for n in edge_datas]):
691
  logger.warning("Some edges are missing, maybe the storage is damaged")
692
  edge_degree = await asyncio.gather(
@@ -765,6 +805,7 @@ async def _build_global_query_context(
765
  ```
766
  """
767
 
 
768
  async def _find_most_related_entities_from_relationships(
769
  edge_datas: list[dict],
770
  query_param: QueryParam,
@@ -774,7 +815,7 @@ async def _find_most_related_entities_from_relationships(
774
  for e in edge_datas:
775
  entity_names.add(e["src_id"])
776
  entity_names.add(e["tgt_id"])
777
-
778
  node_datas = await asyncio.gather(
779
  *[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
780
  )
@@ -795,13 +836,13 @@ async def _find_most_related_entities_from_relationships(
795
 
796
  return node_datas
797
 
 
798
  async def _find_related_text_unit_from_relationships(
799
  edge_datas: list[dict],
800
  query_param: QueryParam,
801
  text_chunks_db: BaseKVStorage[TextChunkSchema],
802
  knowledge_graph_inst: BaseGraphStorage,
803
  ):
804
-
805
  text_units = [
806
  split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
807
  for dp in edge_datas
@@ -816,15 +857,13 @@ async def _find_related_text_unit_from_relationships(
816
  "data": await text_chunks_db.get_by_id(c_id),
817
  "order": index,
818
  }
819
-
820
  if any([v is None for v in all_text_units_lookup.values()]):
821
  logger.warning("Text chunks are missing, maybe the storage is damaged")
822
  all_text_units = [
823
  {"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
824
  ]
825
- all_text_units = sorted(
826
- all_text_units, key=lambda x: x["order"]
827
- )
828
  all_text_units = truncate_list_by_token_size(
829
  all_text_units,
830
  key=lambda x: x["data"]["content"],
@@ -834,6 +873,7 @@ async def _find_related_text_unit_from_relationships(
834
 
835
  return all_text_units
836
 
 
837
  async def hybrid_query(
838
  query,
839
  knowledge_graph_inst: BaseGraphStorage,
@@ -849,24 +889,29 @@ async def hybrid_query(
849
 
850
  kw_prompt_temp = PROMPTS["keywords_extraction"]
851
  kw_prompt = kw_prompt_temp.format(query=query)
852
-
853
  result = await use_model_func(kw_prompt)
854
  try:
855
  keywords_data = json.loads(result)
856
  hl_keywords = keywords_data.get("high_level_keywords", [])
857
  ll_keywords = keywords_data.get("low_level_keywords", [])
858
- hl_keywords = ', '.join(hl_keywords)
859
- ll_keywords = ', '.join(ll_keywords)
860
- except json.JSONDecodeError as e:
861
  try:
862
- result = result.replace(kw_prompt[:-1],'').replace('user','').replace('model','').strip()
863
- result = '{' + result.split('{')[1].split('}')[0] + '}'
 
 
 
 
 
864
 
865
  keywords_data = json.loads(result)
866
  hl_keywords = keywords_data.get("high_level_keywords", [])
867
  ll_keywords = keywords_data.get("low_level_keywords", [])
868
- hl_keywords = ', '.join(hl_keywords)
869
- ll_keywords = ', '.join(ll_keywords)
870
  # Handle parsing error
871
  except json.JSONDecodeError as e:
872
  print(f"JSON parsing error: {e}")
@@ -897,7 +942,7 @@ async def hybrid_query(
897
  return context
898
  if context is None:
899
  return PROMPTS["fail_response"]
900
-
901
  sys_prompt_temp = PROMPTS["rag_response"]
902
  sys_prompt = sys_prompt_temp.format(
903
  context_data=context, response_type=query_param.response_type
@@ -906,53 +951,78 @@ async def hybrid_query(
906
  query,
907
  system_prompt=sys_prompt,
908
  )
909
- if len(response)>len(sys_prompt):
910
- response = response.replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
 
 
 
 
 
 
 
 
911
  return response
912
 
 
913
  def combine_contexts(high_level_context, low_level_context):
914
  # Function to extract entities, relationships, and sources from context strings
915
 
916
  def extract_sections(context):
917
- entities_match = re.search(r'-----Entities-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
918
- relationships_match = re.search(r'-----Relationships-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
919
- sources_match = re.search(r'-----Sources-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
920
-
921
- entities = entities_match.group(1) if entities_match else ''
922
- relationships = relationships_match.group(1) if relationships_match else ''
923
- sources = sources_match.group(1) if sources_match else ''
924
-
 
 
 
 
 
 
925
  return entities, relationships, sources
926
-
927
  # Extract sections from both contexts
928
 
929
- if high_level_context==None:
930
- warnings.warn("High Level context is None. Return empty High entity/relationship/source")
931
- hl_entities, hl_relationships, hl_sources = '','',''
 
 
932
  else:
933
  hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
934
 
935
-
936
- if low_level_context==None:
937
- warnings.warn("Low Level context is None. Return empty Low entity/relationship/source")
938
- ll_entities, ll_relationships, ll_sources = '','',''
 
939
  else:
940
  ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
941
 
942
-
943
-
944
  # Combine and deduplicate the entities
945
- combined_entities_set = set(filter(None, hl_entities.strip().split('\n') + ll_entities.strip().split('\n')))
946
- combined_entities = '\n'.join(combined_entities_set)
947
-
 
 
948
  # Combine and deduplicate the relationships
949
- combined_relationships_set = set(filter(None, hl_relationships.strip().split('\n') + ll_relationships.strip().split('\n')))
950
- combined_relationships = '\n'.join(combined_relationships_set)
951
-
 
 
 
 
 
952
  # Combine and deduplicate the sources
953
- combined_sources_set = set(filter(None, hl_sources.strip().split('\n') + ll_sources.strip().split('\n')))
954
- combined_sources = '\n'.join(combined_sources_set)
955
-
 
 
956
  # Format the combined context
957
  return f"""
958
  -----Entities-----
@@ -964,6 +1034,7 @@ def combine_contexts(high_level_context, low_level_context):
964
  {combined_sources}
965
  """
966
 
 
967
  async def naive_query(
968
  query,
969
  chunks_vdb: BaseVectorStorage,
@@ -996,8 +1067,16 @@ async def naive_query(
996
  system_prompt=sys_prompt,
997
  )
998
 
999
- if len(response)>len(sys_prompt):
1000
- response = response[len(sys_prompt):].replace(sys_prompt,'').replace('user','').replace('model','').replace(query,'').replace('<system>','').replace('</system>','').strip()
1001
-
1002
- return response
 
 
 
 
 
 
 
1003
 
 
 
25
  )
26
  from .prompt import GRAPH_FIELD_SEP, PROMPTS
27
 
28
+
29
  def chunking_by_token_size(
30
  content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
31
  ):
 
46
  )
47
  return results
48
 
49
+
50
  async def _handle_entity_relation_summary(
51
  entity_or_relation_name: str,
52
  description: str,
 
231
  description=description,
232
  keywords=keywords,
233
  )
234
+
235
  return edge_data
236
 
237
+
238
  async def extract_entities(
239
  chunks: dict[str, TextChunkSchema],
240
  knwoledge_graph_inst: BaseGraphStorage,
 
355
  logger.warning("Didn't extract any entities, maybe your LLM is not working")
356
  return None
357
  if not len(all_relationships_data):
358
+ logger.warning(
359
+ "Didn't extract any relationships, maybe your LLM is not working"
360
+ )
361
  return None
362
 
363
  if entity_vdb is not None:
 
375
  compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
376
  "src_id": dp["src_id"],
377
  "tgt_id": dp["tgt_id"],
378
+ "content": dp["keywords"]
379
+ + dp["src_id"]
380
+ + dp["tgt_id"]
381
+ + dp["description"],
382
  }
383
  for dp in all_relationships_data
384
  }
 
386
 
387
  return knwoledge_graph_inst
388
 
389
+
390
  async def local_query(
391
  query,
392
  knowledge_graph_inst: BaseGraphStorage,
 
402
  kw_prompt_temp = PROMPTS["keywords_extraction"]
403
  kw_prompt = kw_prompt_temp.format(query=query)
404
  result = await use_model_func(kw_prompt)
405
+
406
  try:
407
  keywords_data = json.loads(result)
408
  keywords = keywords_data.get("low_level_keywords", [])
409
+ keywords = ", ".join(keywords)
410
+ except json.JSONDecodeError:
411
  try:
412
+ result = (
413
+ result.replace(kw_prompt[:-1], "")
414
+ .replace("user", "")
415
+ .replace("model", "")
416
+ .strip()
417
+ )
418
+ result = "{" + result.split("{")[1].split("}")[0] + "}"
419
 
420
  keywords_data = json.loads(result)
421
  keywords = keywords_data.get("low_level_keywords", [])
422
+ keywords = ", ".join(keywords)
423
  # Handle parsing error
424
  except json.JSONDecodeError as e:
425
  print(f"JSON parsing error: {e}")
 
444
  query,
445
  system_prompt=sys_prompt,
446
  )
447
+ if len(response) > len(sys_prompt):
448
+ response = (
449
+ response.replace(sys_prompt, "")
450
+ .replace("user", "")
451
+ .replace("model", "")
452
+ .replace(query, "")
453
+ .replace("<system>", "")
454
+ .replace("</system>", "")
455
+ .strip()
456
+ )
457
+
458
  return response
459
 
460
+
461
  async def _build_local_query_context(
462
  query,
463
  knowledge_graph_inst: BaseGraphStorage,
 
539
  ```
540
  """
541
 
542
+
543
  async def _find_most_related_text_unit_from_entities(
544
  node_datas: list[dict],
545
  query_param: QueryParam,
 
600
  all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
601
  return all_text_units
602
 
603
+
604
  async def _find_most_related_edges_from_entities(
605
  node_datas: list[dict],
606
  query_param: QueryParam,
 
634
  )
635
  return all_edges_data
636
 
637
+
638
  async def global_query(
639
  query,
640
  knowledge_graph_inst: BaseGraphStorage,
 
650
  kw_prompt_temp = PROMPTS["keywords_extraction"]
651
  kw_prompt = kw_prompt_temp.format(query=query)
652
  result = await use_model_func(kw_prompt)
653
+
654
  try:
655
  keywords_data = json.loads(result)
656
  keywords = keywords_data.get("high_level_keywords", [])
657
+ keywords = ", ".join(keywords)
658
+ except json.JSONDecodeError:
659
  try:
660
+ result = (
661
+ result.replace(kw_prompt[:-1], "")
662
+ .replace("user", "")
663
+ .replace("model", "")
664
+ .strip()
665
+ )
666
+ result = "{" + result.split("{")[1].split("}")[0] + "}"
667
 
668
  keywords_data = json.loads(result)
669
  keywords = keywords_data.get("high_level_keywords", [])
670
+ keywords = ", ".join(keywords)
671
+
672
  except json.JSONDecodeError as e:
673
  # Handle parsing error
674
  print(f"JSON parsing error: {e}")
 
682
  text_chunks_db,
683
  query_param,
684
  )
685
+
686
  if query_param.only_need_context:
687
  return context
688
  if context is None:
689
  return PROMPTS["fail_response"]
690
+
691
  sys_prompt_temp = PROMPTS["rag_response"]
692
  sys_prompt = sys_prompt_temp.format(
693
  context_data=context, response_type=query_param.response_type
 
696
  query,
697
  system_prompt=sys_prompt,
698
  )
699
+ if len(response) > len(sys_prompt):
700
+ response = (
701
+ response.replace(sys_prompt, "")
702
+ .replace("user", "")
703
+ .replace("model", "")
704
+ .replace(query, "")
705
+ .replace("<system>", "")
706
+ .replace("</system>", "")
707
+ .strip()
708
+ )
709
+
710
  return response
711
 
712
+
713
  async def _build_global_query_context(
714
  keywords,
715
  knowledge_graph_inst: BaseGraphStorage,
 
719
  query_param: QueryParam,
720
  ):
721
  results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
722
+
723
  if not len(results):
724
  return None
725
+
726
  edge_datas = await asyncio.gather(
727
  *[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
728
  )
729
+
730
  if not all([n is not None for n in edge_datas]):
731
  logger.warning("Some edges are missing, maybe the storage is damaged")
732
  edge_degree = await asyncio.gather(
 
805
  ```
806
  """
807
 
808
+
809
  async def _find_most_related_entities_from_relationships(
810
  edge_datas: list[dict],
811
  query_param: QueryParam,
 
815
  for e in edge_datas:
816
  entity_names.add(e["src_id"])
817
  entity_names.add(e["tgt_id"])
818
+
819
  node_datas = await asyncio.gather(
820
  *[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
821
  )
 
836
 
837
  return node_datas
838
 
839
+
840
  async def _find_related_text_unit_from_relationships(
841
  edge_datas: list[dict],
842
  query_param: QueryParam,
843
  text_chunks_db: BaseKVStorage[TextChunkSchema],
844
  knowledge_graph_inst: BaseGraphStorage,
845
  ):
 
846
  text_units = [
847
  split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
848
  for dp in edge_datas
 
857
  "data": await text_chunks_db.get_by_id(c_id),
858
  "order": index,
859
  }
860
+
861
  if any([v is None for v in all_text_units_lookup.values()]):
862
  logger.warning("Text chunks are missing, maybe the storage is damaged")
863
  all_text_units = [
864
  {"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
865
  ]
866
+ all_text_units = sorted(all_text_units, key=lambda x: x["order"])
 
 
867
  all_text_units = truncate_list_by_token_size(
868
  all_text_units,
869
  key=lambda x: x["data"]["content"],
 
873
 
874
  return all_text_units
875
 
876
+
877
  async def hybrid_query(
878
  query,
879
  knowledge_graph_inst: BaseGraphStorage,
 
889
 
890
  kw_prompt_temp = PROMPTS["keywords_extraction"]
891
  kw_prompt = kw_prompt_temp.format(query=query)
892
+
893
  result = await use_model_func(kw_prompt)
894
  try:
895
  keywords_data = json.loads(result)
896
  hl_keywords = keywords_data.get("high_level_keywords", [])
897
  ll_keywords = keywords_data.get("low_level_keywords", [])
898
+ hl_keywords = ", ".join(hl_keywords)
899
+ ll_keywords = ", ".join(ll_keywords)
900
+ except json.JSONDecodeError:
901
  try:
902
+ result = (
903
+ result.replace(kw_prompt[:-1], "")
904
+ .replace("user", "")
905
+ .replace("model", "")
906
+ .strip()
907
+ )
908
+ result = "{" + result.split("{")[1].split("}")[0] + "}"
909
 
910
  keywords_data = json.loads(result)
911
  hl_keywords = keywords_data.get("high_level_keywords", [])
912
  ll_keywords = keywords_data.get("low_level_keywords", [])
913
+ hl_keywords = ", ".join(hl_keywords)
914
+ ll_keywords = ", ".join(ll_keywords)
915
  # Handle parsing error
916
  except json.JSONDecodeError as e:
917
  print(f"JSON parsing error: {e}")
 
942
  return context
943
  if context is None:
944
  return PROMPTS["fail_response"]
945
+
946
  sys_prompt_temp = PROMPTS["rag_response"]
947
  sys_prompt = sys_prompt_temp.format(
948
  context_data=context, response_type=query_param.response_type
 
951
  query,
952
  system_prompt=sys_prompt,
953
  )
954
+ if len(response) > len(sys_prompt):
955
+ response = (
956
+ response.replace(sys_prompt, "")
957
+ .replace("user", "")
958
+ .replace("model", "")
959
+ .replace(query, "")
960
+ .replace("<system>", "")
961
+ .replace("</system>", "")
962
+ .strip()
963
+ )
964
  return response
965
 
966
+
967
  def combine_contexts(high_level_context, low_level_context):
968
  # Function to extract entities, relationships, and sources from context strings
969
 
970
  def extract_sections(context):
971
+ entities_match = re.search(
972
+ r"-----Entities-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
973
+ )
974
+ relationships_match = re.search(
975
+ r"-----Relationships-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
976
+ )
977
+ sources_match = re.search(
978
+ r"-----Sources-----\s*```csv\s*(.*?)\s*```", context, re.DOTALL
979
+ )
980
+
981
+ entities = entities_match.group(1) if entities_match else ""
982
+ relationships = relationships_match.group(1) if relationships_match else ""
983
+ sources = sources_match.group(1) if sources_match else ""
984
+
985
  return entities, relationships, sources
986
+
987
  # Extract sections from both contexts
988
 
989
+ if high_level_context is None:
990
+ warnings.warn(
991
+ "High Level context is None. Return empty High entity/relationship/source"
992
+ )
993
+ hl_entities, hl_relationships, hl_sources = "", "", ""
994
  else:
995
  hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
996
 
997
+ if low_level_context is None:
998
+ warnings.warn(
999
+ "Low Level context is None. Return empty Low entity/relationship/source"
1000
+ )
1001
+ ll_entities, ll_relationships, ll_sources = "", "", ""
1002
  else:
1003
  ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
1004
 
 
 
1005
  # Combine and deduplicate the entities
1006
+ combined_entities_set = set(
1007
+ filter(None, hl_entities.strip().split("\n") + ll_entities.strip().split("\n"))
1008
+ )
1009
+ combined_entities = "\n".join(combined_entities_set)
1010
+
1011
  # Combine and deduplicate the relationships
1012
+ combined_relationships_set = set(
1013
+ filter(
1014
+ None,
1015
+ hl_relationships.strip().split("\n") + ll_relationships.strip().split("\n"),
1016
+ )
1017
+ )
1018
+ combined_relationships = "\n".join(combined_relationships_set)
1019
+
1020
  # Combine and deduplicate the sources
1021
+ combined_sources_set = set(
1022
+ filter(None, hl_sources.strip().split("\n") + ll_sources.strip().split("\n"))
1023
+ )
1024
+ combined_sources = "\n".join(combined_sources_set)
1025
+
1026
  # Format the combined context
1027
  return f"""
1028
  -----Entities-----
 
1034
  {combined_sources}
1035
  """
1036
 
1037
+
1038
  async def naive_query(
1039
  query,
1040
  chunks_vdb: BaseVectorStorage,
 
1067
  system_prompt=sys_prompt,
1068
  )
1069
 
1070
+ if len(response) > len(sys_prompt):
1071
+ response = (
1072
+ response[len(sys_prompt) :]
1073
+ .replace(sys_prompt, "")
1074
+ .replace("user", "")
1075
+ .replace("model", "")
1076
+ .replace(query, "")
1077
+ .replace("<system>", "")
1078
+ .replace("</system>", "")
1079
+ .strip()
1080
+ )
1081
 
1082
+ return response
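A recurring pattern in the operate.py changes above is the keyword-extraction fallback used by local_query, global_query, and hybrid_query: when the LLM response is not clean JSON, the echoed prompt and any "user"/"model" markers are stripped and the first {...} block is re-parsed. A standalone sketch of that salvage step, with a hypothetical function name and signature:

```python
import json


def salvage_keyword_json(result: str, kw_prompt: str, key: str) -> str:
    """Illustrative re-parse of a noisy keywords_extraction response.

    Mirrors the fallback above: try json.loads first, then strip the echoed
    prompt and role markers and retry on the first {...} block. If the
    response cannot be repaired, the second attempt raises.
    """
    try:
        keywords = json.loads(result).get(key, [])
    except json.JSONDecodeError:
        cleaned = (
            result.replace(kw_prompt[:-1], "")
            .replace("user", "")
            .replace("model", "")
            .strip()
        )
        cleaned = "{" + cleaned.split("{")[1].split("}")[0] + "}"
        keywords = json.loads(cleaned).get(key, [])
    return ", ".join(keywords)


# Example with a response that wraps the JSON in extra chatter:
noisy = 'model Sure! {"low_level_keywords": ["soil", "crop rotation"]}'
print(salvage_keyword_json(noisy, "Extract keywords from the query.", "low_level_keywords"))
# soil, crop rotation
```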
lightrag/prompt.py CHANGED
@@ -9,9 +9,7 @@ PROMPTS["process_tickers"] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "
9
 
10
  PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"]
11
 
12
- PROMPTS[
13
- "entity_extraction"
14
- ] = """-Goal-
15
  Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
16
 
17
  -Steps-
@@ -32,7 +30,7 @@ Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tupl
32
 
33
  3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
34
  Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_level_keywords>)
35
-
36
  4. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
37
 
38
  5. When finished, output {completion_delimiter}
@@ -146,9 +144,7 @@ PROMPTS[
146
 
147
  PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."
148
 
149
- PROMPTS[
150
- "rag_response"
151
- ] = """---Role---
152
 
153
  You are a helpful assistant responding to questions about data in the tables provided.
154
 
@@ -226,9 +222,7 @@ Output:
226
 
227
  """
228
 
229
- PROMPTS[
230
- "naive_rag_response"
231
- ] = """You're a helpful assistant
232
  Below are the knowledge you know:
233
  {content_data}
234
  ---
 
9
 
10
  PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"]
11
 
12
+ PROMPTS["entity_extraction"] = """-Goal-
 
 
13
  Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
14
 
15
  -Steps-
 
30
 
31
  3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
32
  Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_level_keywords>)
33
+
34
  4. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
35
 
36
  5. When finished, output {completion_delimiter}
 
144
 
145
  PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."
146
 
147
+ PROMPTS["rag_response"] = """---Role---
 
 
148
 
149
  You are a helpful assistant responding to questions about data in the tables provided.
150
 
 
222
 
223
  """
224
 
225
+ PROMPTS["naive_rag_response"] = """You're a helpful assistant
 
 
226
  Below are the knowledge you know:
227
  {content_data}
228
  ---
lightrag/storage.py CHANGED
@@ -1,16 +1,11 @@
1
  import asyncio
2
  import html
3
- import json
4
  import os
5
- from collections import defaultdict
6
- from dataclasses import dataclass, field
7
  from typing import Any, Union, cast
8
- import pickle
9
- import hnswlib
10
  import networkx as nx
11
  import numpy as np
12
  from nano_vectordb import NanoVectorDB
13
- import xxhash
14
 
15
  from .utils import load_json, logger, write_json
16
  from .base import (
@@ -19,6 +14,7 @@ from .base import (
19
  BaseVectorStorage,
20
  )
21
 
 
22
  @dataclass
23
  class JsonKVStorage(BaseKVStorage):
24
  def __post_init__(self):
@@ -59,12 +55,12 @@ class JsonKVStorage(BaseKVStorage):
59
  async def drop(self):
60
  self._data = {}
61
 
 
62
  @dataclass
63
  class NanoVectorDBStorage(BaseVectorStorage):
64
  cosine_better_than_threshold: float = 0.2
65
 
66
  def __post_init__(self):
67
-
68
  self._client_file_name = os.path.join(
69
  self.global_config["working_dir"], f"vdb_{self.namespace}.json"
70
  )
@@ -118,6 +114,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
118
  async def index_done_callback(self):
119
  self._client.save()
120
 
 
121
  @dataclass
122
  class NetworkXStorage(BaseGraphStorage):
123
  @staticmethod
@@ -142,7 +139,9 @@ class NetworkXStorage(BaseGraphStorage):
142
 
143
  graph = graph.copy()
144
  graph = cast(nx.Graph, largest_connected_component(graph))
145
- node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()} # type: ignore
 
 
146
  graph = nx.relabel_nodes(graph, node_mapping)
147
  return NetworkXStorage._stabilize_graph(graph)
148
 
 
1
  import asyncio
2
  import html
 
3
  import os
4
+ from dataclasses import dataclass
 
5
  from typing import Any, Union, cast
 
 
6
  import networkx as nx
7
  import numpy as np
8
  from nano_vectordb import NanoVectorDB
 
9
 
10
  from .utils import load_json, logger, write_json
11
  from .base import (
 
14
  BaseVectorStorage,
15
  )
16
 
17
+
18
  @dataclass
19
  class JsonKVStorage(BaseKVStorage):
20
  def __post_init__(self):
 
55
  async def drop(self):
56
  self._data = {}
57
 
58
+
59
  @dataclass
60
  class NanoVectorDBStorage(BaseVectorStorage):
61
  cosine_better_than_threshold: float = 0.2
62
 
63
  def __post_init__(self):
 
64
  self._client_file_name = os.path.join(
65
  self.global_config["working_dir"], f"vdb_{self.namespace}.json"
66
  )
 
114
  async def index_done_callback(self):
115
  self._client.save()
116
 
117
+
118
  @dataclass
119
  class NetworkXStorage(BaseGraphStorage):
120
  @staticmethod
 
139
 
140
  graph = graph.copy()
141
  graph = cast(nx.Graph, largest_connected_component(graph))
142
+ node_mapping = {
143
+ node: html.unescape(node.upper().strip()) for node in graph.nodes()
144
+ } # type: ignore
145
  graph = nx.relabel_nodes(graph, node_mapping)
146
  return NetworkXStorage._stabilize_graph(graph)
147
 
lightrag/utils.py CHANGED
@@ -16,18 +16,22 @@ ENCODER = None
16
 
17
  logger = logging.getLogger("lightrag")
18
 
 
19
  def set_logger(log_file: str):
20
  logger.setLevel(logging.DEBUG)
21
 
22
  file_handler = logging.FileHandler(log_file)
23
  file_handler.setLevel(logging.DEBUG)
24
 
25
- formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
 
 
26
  file_handler.setFormatter(formatter)
27
 
28
  if not logger.handlers:
29
  logger.addHandler(file_handler)
30
 
 
31
  @dataclass
32
  class EmbeddingFunc:
33
  embedding_dim: int
@@ -36,7 +40,8 @@ class EmbeddingFunc:
36
 
37
  async def __call__(self, *args, **kwargs) -> np.ndarray:
38
  return await self.func(*args, **kwargs)
39
-
 
40
  def locate_json_string_body_from_string(content: str) -> Union[str, None]:
41
  """Locate the JSON string body from a string"""
42
  maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
@@ -45,6 +50,7 @@ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
45
  else:
46
  return None
47
 
 
48
  def convert_response_to_json(response: str) -> dict:
49
  json_str = locate_json_string_body_from_string(response)
50
  assert json_str is not None, f"Unable to parse JSON from response: {response}"
@@ -55,12 +61,15 @@ def convert_response_to_json(response: str) -> dict:
55
  logger.error(f"Failed to parse JSON: {json_str}")
56
  raise e from None
57
 
 
58
  def compute_args_hash(*args):
59
  return md5(str(args).encode()).hexdigest()
60
 
 
61
  def compute_mdhash_id(content, prefix: str = ""):
62
  return prefix + md5(content.encode()).hexdigest()
63
 
 
64
  def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
65
  """Add restriction of maximum async calling times for a async func"""
66
 
@@ -82,6 +91,7 @@ def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
82
 
83
  return final_decro
84
 
 
85
  def wrap_embedding_func_with_attrs(**kwargs):
86
  """Wrap a function with attributes"""
87
 
@@ -91,16 +101,19 @@ def wrap_embedding_func_with_attrs(**kwargs):
91
 
92
  return final_decro
93
 
 
94
  def load_json(file_name):
95
  if not os.path.exists(file_name):
96
  return None
97
  with open(file_name, encoding="utf-8") as f:
98
  return json.load(f)
99
 
 
100
  def write_json(json_obj, file_name):
101
  with open(file_name, "w", encoding="utf-8") as f:
102
  json.dump(json_obj, f, indent=2, ensure_ascii=False)
103
 
 
104
  def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
105
  global ENCODER
106
  if ENCODER is None:
@@ -116,12 +129,14 @@ def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
116
  content = ENCODER.decode(tokens)
117
  return content
118
 
 
119
  def pack_user_ass_to_openai_messages(*args: str):
120
  roles = ["user", "assistant"]
121
  return [
122
  {"role": roles[i % 2], "content": content} for i, content in enumerate(args)
123
  ]
124
 
 
125
  def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
126
  """Split a string by multiple markers"""
127
  if not markers:
@@ -129,6 +144,7 @@ def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]
129
  results = re.split("|".join(re.escape(marker) for marker in markers), content)
130
  return [r.strip() for r in results if r.strip()]
131
 
 
132
  # Refer the utils functions of the official GraphRAG implementation:
133
  # https://github.com/microsoft/graphrag
134
  def clean_str(input: Any) -> str:
@@ -141,9 +157,11 @@ def clean_str(input: Any) -> str:
141
  # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
142
  return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
143
 
 
144
  def is_float_regex(value):
145
  return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
146
 
 
147
  def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
148
  """Truncate a list of data by token size"""
149
  if max_token_size <= 0:
@@ -155,11 +173,13 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
155
  return list_data[:i]
156
  return list_data
157
 
 
158
  def list_of_list_to_csv(data: list[list]):
159
  return "\n".join(
160
  [",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data]
161
  )
162
 
 
163
  def save_data_to_file(data, file_name):
164
- with open(file_name, 'w', encoding='utf-8') as f:
165
- json.dump(data, f, ensure_ascii=False, indent=4)
 
16
 
17
  logger = logging.getLogger("lightrag")
18
 
19
+
20
  def set_logger(log_file: str):
21
  logger.setLevel(logging.DEBUG)
22
 
23
  file_handler = logging.FileHandler(log_file)
24
  file_handler.setLevel(logging.DEBUG)
25
 
26
+ formatter = logging.Formatter(
27
+ "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
28
+ )
29
  file_handler.setFormatter(formatter)
30
 
31
  if not logger.handlers:
32
  logger.addHandler(file_handler)
33
 
34
+
35
  @dataclass
36
  class EmbeddingFunc:
37
  embedding_dim: int
 
40
 
41
  async def __call__(self, *args, **kwargs) -> np.ndarray:
42
  return await self.func(*args, **kwargs)
43
+
44
+
45
  def locate_json_string_body_from_string(content: str) -> Union[str, None]:
46
  """Locate the JSON string body from a string"""
47
  maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
 
50
  else:
51
  return None
52
 
53
+
54
  def convert_response_to_json(response: str) -> dict:
55
  json_str = locate_json_string_body_from_string(response)
56
  assert json_str is not None, f"Unable to parse JSON from response: {response}"
 
61
  logger.error(f"Failed to parse JSON: {json_str}")
62
  raise e from None
63
 
64
+
65
  def compute_args_hash(*args):
66
  return md5(str(args).encode()).hexdigest()
67
 
68
+
69
  def compute_mdhash_id(content, prefix: str = ""):
70
  return prefix + md5(content.encode()).hexdigest()
71
 
72
+
73
  def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
74
  """Add restriction of maximum async calling times for a async func"""
75
 
 
91
 
92
  return final_decro
93
 
94
+
95
  def wrap_embedding_func_with_attrs(**kwargs):
96
  """Wrap a function with attributes"""
97
 
 
101
 
102
  return final_decro
103
 
104
+
105
  def load_json(file_name):
106
  if not os.path.exists(file_name):
107
  return None
108
  with open(file_name, encoding="utf-8") as f:
109
  return json.load(f)
110
 
111
+
112
  def write_json(json_obj, file_name):
113
  with open(file_name, "w", encoding="utf-8") as f:
114
  json.dump(json_obj, f, indent=2, ensure_ascii=False)
115
 
116
+
117
  def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
118
  global ENCODER
119
  if ENCODER is None:
 
129
  content = ENCODER.decode(tokens)
130
  return content
131
 
132
+
133
  def pack_user_ass_to_openai_messages(*args: str):
134
  roles = ["user", "assistant"]
135
  return [
136
  {"role": roles[i % 2], "content": content} for i, content in enumerate(args)
137
  ]
138
 
139
+
140
  def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
141
  """Split a string by multiple markers"""
142
  if not markers:
 
144
  results = re.split("|".join(re.escape(marker) for marker in markers), content)
145
  return [r.strip() for r in results if r.strip()]
146
 
147
+
148
  # Refer the utils functions of the official GraphRAG implementation:
149
  # https://github.com/microsoft/graphrag
150
  def clean_str(input: Any) -> str:
 
157
  # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
158
  return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
159
 
160
+
161
  def is_float_regex(value):
162
  return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
163
 
164
+
165
  def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
166
  """Truncate a list of data by token size"""
167
  if max_token_size <= 0:
 
173
  return list_data[:i]
174
  return list_data
175
 
176
+
177
  def list_of_list_to_csv(data: list[list]):
178
  return "\n".join(
179
  [",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data]
180
  )
181
 
182
+
183
  def save_data_to_file(data, file_name):
184
+ with open(file_name, "w", encoding="utf-8") as f:
185
+ json.dump(data, f, ensure_ascii=False, indent=4)
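The utils.py diff above is formatting-only, so helper behaviour is unchanged. For orientation, a small sketch of two of the helpers touched here, compute_mdhash_id and split_string_by_multi_markers; the input strings are invented:

```python
from lightrag.utils import compute_mdhash_id, split_string_by_multi_markers

# compute_mdhash_id returns the given prefix followed by the md5 hex digest
# of the content, which is how chunk/entity/relationship IDs are built.
print(compute_mdhash_id("some chunk of text", prefix="chunk-"))

# split_string_by_multi_markers splits on any of the markers and strips the pieces.
print(split_string_by_multi_markers("entity one<SEP>entity two<SEP> entity three ", ["<SEP>"]))
# ['entity one', 'entity two', 'entity three']
```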
reproduce/Step_0.py CHANGED
@@ -3,11 +3,11 @@ import json
3
  import glob
4
  import argparse
5
 
6
- def extract_unique_contexts(input_directory, output_directory):
7
 
 
8
  os.makedirs(output_directory, exist_ok=True)
9
 
10
- jsonl_files = glob.glob(os.path.join(input_directory, '*.jsonl'))
11
  print(f"Found {len(jsonl_files)} JSONL files.")
12
 
13
  for file_path in jsonl_files:
@@ -21,18 +21,20 @@ def extract_unique_contexts(input_directory, output_directory):
21
  print(f"Processing file: {filename}")
22
 
23
  try:
24
- with open(file_path, 'r', encoding='utf-8') as infile:
25
  for line_number, line in enumerate(infile, start=1):
26
  line = line.strip()
27
  if not line:
28
  continue
29
  try:
30
  json_obj = json.loads(line)
31
- context = json_obj.get('context')
32
  if context and context not in unique_contexts_dict:
33
  unique_contexts_dict[context] = None
34
  except json.JSONDecodeError as e:
35
- print(f"JSON decoding error in file {filename} at line {line_number}: {e}")
 
 
36
  except FileNotFoundError:
37
  print(f"File not found: {filename}")
38
  continue
@@ -41,10 +43,12 @@ def extract_unique_contexts(input_directory, output_directory):
41
  continue
42
 
43
  unique_contexts_list = list(unique_contexts_dict.keys())
44
- print(f"There are {len(unique_contexts_list)} unique `context` entries in the file {filename}.")
 
 
45
 
46
  try:
47
- with open(output_path, 'w', encoding='utf-8') as outfile:
48
  json.dump(unique_contexts_list, outfile, ensure_ascii=False, indent=4)
49
  print(f"Unique `context` entries have been saved to: {output_filename}")
50
  except Exception as e:
@@ -55,8 +59,10 @@ def extract_unique_contexts(input_directory, output_directory):
55
 
56
  if __name__ == "__main__":
57
  parser = argparse.ArgumentParser()
58
- parser.add_argument('-i', '--input_dir', type=str, default='../datasets')
59
- parser.add_argument('-o', '--output_dir', type=str, default='../datasets/unique_contexts')
 
 
60
 
61
  args = parser.parse_args()
62
 
 
3
  import glob
4
  import argparse
5
 
 
6
 
7
+ def extract_unique_contexts(input_directory, output_directory):
8
  os.makedirs(output_directory, exist_ok=True)
9
 
10
+ jsonl_files = glob.glob(os.path.join(input_directory, "*.jsonl"))
11
  print(f"Found {len(jsonl_files)} JSONL files.")
12
 
13
  for file_path in jsonl_files:
 
21
  print(f"Processing file: {filename}")
22
 
23
  try:
24
+ with open(file_path, "r", encoding="utf-8") as infile:
25
  for line_number, line in enumerate(infile, start=1):
26
  line = line.strip()
27
  if not line:
28
  continue
29
  try:
30
  json_obj = json.loads(line)
31
+ context = json_obj.get("context")
32
  if context and context not in unique_contexts_dict:
33
  unique_contexts_dict[context] = None
34
  except json.JSONDecodeError as e:
35
+ print(
36
+ f"JSON decoding error in file {filename} at line {line_number}: {e}"
37
+ )
38
  except FileNotFoundError:
39
  print(f"File not found: {filename}")
40
  continue
 
43
  continue
44
 
45
  unique_contexts_list = list(unique_contexts_dict.keys())
46
+ print(
47
+ f"There are {len(unique_contexts_list)} unique `context` entries in the file {filename}."
48
+ )
49
 
50
  try:
51
+ with open(output_path, "w", encoding="utf-8") as outfile:
52
  json.dump(unique_contexts_list, outfile, ensure_ascii=False, indent=4)
53
  print(f"Unique `context` entries have been saved to: {output_filename}")
54
  except Exception as e:
 
59
 
60
  if __name__ == "__main__":
61
  parser = argparse.ArgumentParser()
62
+ parser.add_argument("-i", "--input_dir", type=str, default="../datasets")
63
+ parser.add_argument(
64
+ "-o", "--output_dir", type=str, default="../datasets/unique_contexts"
65
+ )
66
 
67
  args = parser.parse_args()
68
 
reproduce/Step_1.py CHANGED
@@ -4,10 +4,11 @@ import time
4
 
5
  from lightrag import LightRAG
6
 
 
7
  def insert_text(rag, file_path):
8
- with open(file_path, mode='r') as f:
9
  unique_contexts = json.load(f)
10
-
11
  retries = 0
12
  max_retries = 3
13
  while retries < max_retries:
@@ -21,6 +22,7 @@ def insert_text(rag, file_path):
21
  if retries == max_retries:
22
  print("Insertion failed after exceeding the maximum number of retries")
23
 
 
24
  cls = "agriculture"
25
  WORKING_DIR = "../{cls}"
26
 
@@ -29,4 +31,4 @@ if not os.path.exists(WORKING_DIR):
29
 
30
  rag = LightRAG(working_dir=WORKING_DIR)
31
 
32
- insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")
 
4
 
5
  from lightrag import LightRAG
6
 
7
+
8
  def insert_text(rag, file_path):
9
+ with open(file_path, mode="r") as f:
10
  unique_contexts = json.load(f)
11
+
12
  retries = 0
13
  max_retries = 3
14
  while retries < max_retries:
 
22
  if retries == max_retries:
23
  print("Insertion failed after exceeding the maximum number of retries")
24
 
25
+
26
  cls = "agriculture"
27
  WORKING_DIR = "../{cls}"
28
 
 
31
 
32
  rag = LightRAG(working_dir=WORKING_DIR)
33
 
34
+ insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")
reproduce/Step_1_openai_compatible.py CHANGED
@@ -7,6 +7,7 @@ from lightrag import LightRAG
7
  from lightrag.utils import EmbeddingFunc
8
  from lightrag.llm import openai_complete_if_cache, openai_embedding
9
 
 
10
  ## For Upstage API
11
  # please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag directory
12
  async def llm_model_func(
@@ -19,22 +20,26 @@ async def llm_model_func(
19
  history_messages=history_messages,
20
  api_key=os.getenv("UPSTAGE_API_KEY"),
21
  base_url="https://api.upstage.ai/v1/solar",
22
- **kwargs
23
  )
24
 
 
25
  async def embedding_func(texts: list[str]) -> np.ndarray:
26
  return await openai_embedding(
27
  texts,
28
  model="solar-embedding-1-large-query",
29
  api_key=os.getenv("UPSTAGE_API_KEY"),
30
- base_url="https://api.upstage.ai/v1/solar"
31
  )
 
 
32
  ## /For Upstage API
33
 
 
34
  def insert_text(rag, file_path):
35
- with open(file_path, mode='r') as f:
36
  unique_contexts = json.load(f)
37
-
38
  retries = 0
39
  max_retries = 3
40
  while retries < max_retries:
@@ -48,19 +53,19 @@ def insert_text(rag, file_path):
48
  if retries == max_retries:
49
  print("Insertion failed after exceeding the maximum number of retries")
50
 
 
51
  cls = "mix"
52
  WORKING_DIR = f"../{cls}"
53
 
54
  if not os.path.exists(WORKING_DIR):
55
  os.mkdir(WORKING_DIR)
56
 
57
- rag = LightRAG(working_dir=WORKING_DIR,
58
- llm_model_func=llm_model_func,
59
- embedding_func=EmbeddingFunc(
60
- embedding_dim=4096,
61
- max_token_size=8192,
62
- func=embedding_func
63
- )
64
- )
65
 
66
  insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")
 
7
  from lightrag.utils import EmbeddingFunc
8
  from lightrag.llm import openai_complete_if_cache, openai_embedding
9
 
10
+
11
  ## For Upstage API
12
  # please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag directory
13
  async def llm_model_func(
 
20
  history_messages=history_messages,
21
  api_key=os.getenv("UPSTAGE_API_KEY"),
22
  base_url="https://api.upstage.ai/v1/solar",
23
+ **kwargs,
24
  )
25
 
26
+
27
  async def embedding_func(texts: list[str]) -> np.ndarray:
28
  return await openai_embedding(
29
  texts,
30
  model="solar-embedding-1-large-query",
31
  api_key=os.getenv("UPSTAGE_API_KEY"),
32
+ base_url="https://api.upstage.ai/v1/solar",
33
  )
34
+
35
+
36
  ## /For Upstage API
37
 
38
+
39
  def insert_text(rag, file_path):
40
+ with open(file_path, mode="r") as f:
41
  unique_contexts = json.load(f)
42
+
43
  retries = 0
44
  max_retries = 3
45
  while retries < max_retries:
 
53
  if retries == max_retries:
54
  print("Insertion failed after exceeding the maximum number of retries")
55
 
56
+
57
  cls = "mix"
58
  WORKING_DIR = f"../{cls}"
59
 
60
  if not os.path.exists(WORKING_DIR):
61
  os.mkdir(WORKING_DIR)
62
 
63
+ rag = LightRAG(
64
+ working_dir=WORKING_DIR,
65
+ llm_model_func=llm_model_func,
66
+ embedding_func=EmbeddingFunc(
67
+ embedding_dim=4096, max_token_size=8192, func=embedding_func
68
+ ),
69
+ )
 
70
 
71
  insert_text(rag, f"../datasets/unique_contexts/{cls}_unique_contexts.json")
reproduce/Step_2.py CHANGED
@@ -1,8 +1,8 @@
1
- import os
2
  import json
3
  from openai import OpenAI
4
  from transformers import GPT2Tokenizer
5
 
 
6
  def openai_complete_if_cache(
7
  model="gpt-4o", prompt=None, system_prompt=None, history_messages=[], **kwargs
8
  ) -> str:
@@ -19,24 +19,26 @@ def openai_complete_if_cache(
19
  )
20
  return response.choices[0].message.content
21
 
22
- tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
 
 
23
 
24
  def get_summary(context, tot_tokens=2000):
25
  tokens = tokenizer.tokenize(context)
26
  half_tokens = tot_tokens // 2
27
 
28
- start_tokens = tokens[1000:1000 + half_tokens]
29
- end_tokens = tokens[-(1000 + half_tokens):1000]
30
 
31
  summary_tokens = start_tokens + end_tokens
32
  summary = tokenizer.convert_tokens_to_string(summary_tokens)
33
-
34
  return summary
35
 
36
 
37
- clses = ['agriculture']
38
  for cls in clses:
39
- with open(f'../datasets/unique_contexts/{cls}_unique_contexts.json', mode='r') as f:
40
  unique_contexts = json.load(f)
41
 
42
  summaries = [get_summary(context) for context in unique_contexts]
@@ -67,10 +69,10 @@ for cls in clses:
67
  ...
68
  """
69
 
70
- result = openai_complete_if_cache(model='gpt-4o', prompt=prompt)
71
 
72
  file_path = f"../datasets/questions/{cls}_questions.txt"
73
  with open(file_path, "w") as file:
74
  file.write(result)
75
 
76
- print(f"{cls}_questions written to {file_path}")
 
 
1
  import json
2
  from openai import OpenAI
3
  from transformers import GPT2Tokenizer
4
 
5
+
6
  def openai_complete_if_cache(
7
  model="gpt-4o", prompt=None, system_prompt=None, history_messages=[], **kwargs
8
  ) -> str:
 
19
  )
20
  return response.choices[0].message.content
21
 
22
+
23
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
24
+
25
 
26
  def get_summary(context, tot_tokens=2000):
27
  tokens = tokenizer.tokenize(context)
28
  half_tokens = tot_tokens // 2
29
 
30
+ start_tokens = tokens[1000 : 1000 + half_tokens]
31
+ end_tokens = tokens[-(1000 + half_tokens) : 1000]
32
 
33
  summary_tokens = start_tokens + end_tokens
34
  summary = tokenizer.convert_tokens_to_string(summary_tokens)
35
+
36
  return summary
37
 
38
 
39
+ clses = ["agriculture"]
40
  for cls in clses:
41
+ with open(f"../datasets/unique_contexts/{cls}_unique_contexts.json", mode="r") as f:
42
  unique_contexts = json.load(f)
43
 
44
  summaries = [get_summary(context) for context in unique_contexts]
 
69
  ...
70
  """
71
 
72
+ result = openai_complete_if_cache(model="gpt-4o", prompt=prompt)
73
 
74
  file_path = f"../datasets/questions/{cls}_questions.txt"
75
  with open(file_path, "w") as file:
76
  file.write(result)
77
 
78
+ print(f"{cls}_questions written to {file_path}")
reproduce/Step_3.py CHANGED
@@ -4,16 +4,18 @@ import asyncio
 from lightrag import LightRAG, QueryParam
 from tqdm import tqdm

+
 def extract_queries(file_path):
-    with open(file_path, 'r') as f:
+    with open(file_path, "r") as f:
         data = f.read()

-    data = data.replace('**', '')
+    data = data.replace("**", "")

-    queries = re.findall(r'- Question \d+: (.+)', data)
+    queries = re.findall(r"- Question \d+: (.+)", data)

     return queries

+
 async def process_query(query_text, rag_instance, query_param):
     try:
         result, context = await rag_instance.aquery(query_text, param=query_param)
@@ -21,6 +23,7 @@ async def process_query(query_text, rag_instance, query_param):
     except Exception as e:
         return None, {"query": query_text, "error": str(e)}

+
 def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     try:
         loop = asyncio.get_event_loop()
@@ -29,15 +32,22 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
         asyncio.set_event_loop(loop)
     return loop

-def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file, error_file):
+
+def run_queries_and_save_to_json(
+    queries, rag_instance, query_param, output_file, error_file
+):
     loop = always_get_an_event_loop()

-    with open(output_file, 'a', encoding='utf-8') as result_file, open(error_file, 'a', encoding='utf-8') as err_file:
+    with open(output_file, "a", encoding="utf-8") as result_file, open(
+        error_file, "a", encoding="utf-8"
+    ) as err_file:
         result_file.write("[\n")
         first_entry = True

         for query_text in tqdm(queries, desc="Processing queries", unit="query"):
-            result, error = loop.run_until_complete(process_query(query_text, rag_instance, query_param))
+            result, error = loop.run_until_complete(
+                process_query(query_text, rag_instance, query_param)
+            )

             if result:
                 if not first_entry:
@@ -50,6 +60,7 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file

         result_file.write("\n]")

+
 if __name__ == "__main__":
     cls = "agriculture"
     mode = "hybrid"
@@ -59,4 +70,6 @@ if __name__ == "__main__":
     query_param = QueryParam(mode=mode)

     queries = extract_queries(f"../datasets/questions/{cls}_questions.txt")
-    run_queries_and_save_to_json(queries, rag, query_param, f"{cls}_result.json", f"{cls}_errors.json")
+    run_queries_and_save_to_json(
+        queries, rag, query_param, f"{cls}_result.json", f"{cls}_errors.json"
+    )
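The ruff-format pass above only reflows quoting, blank lines, and long calls; the helper names are unchanged, so the script can still be reused for other dataset classes. A usage sketch, assuming it is run from the `reproduce/` directory so `Step_3.py` is importable, and that the corresponding question file exists (`"cs"` is a hypothetical class; the script itself defaults to `"agriculture"`):

```python
# Hypothetical reuse of the Step_3 helpers for another dataset class.
from Step_3 import extract_queries, run_queries_and_save_to_json

from lightrag import LightRAG, QueryParam

cls = "cs"  # assumed class name; change to match your ../datasets layout
rag = LightRAG(working_dir=f"../{cls}")
query_param = QueryParam(mode="hybrid")

queries = extract_queries(f"../datasets/questions/{cls}_questions.txt")
run_queries_and_save_to_json(
    queries, rag, query_param, f"{cls}_result.json", f"{cls}_errors.json"
)
```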
reproduce/Step_3_openai_compatible.py CHANGED
@@ -8,6 +8,7 @@ from lightrag.llm import openai_complete_if_cache, openai_embedding
 from lightrag.utils import EmbeddingFunc
 import numpy as np

+
 ## For Upstage API
 # please check if embedding_dim=4096 in lightrag.py and llm.py in lightrag direcotry
 async def llm_model_func(
@@ -20,28 +21,33 @@ async def llm_model_func(
         history_messages=history_messages,
         api_key=os.getenv("UPSTAGE_API_KEY"),
         base_url="https://api.upstage.ai/v1/solar",
-        **kwargs
+        **kwargs,
     )

+
 async def embedding_func(texts: list[str]) -> np.ndarray:
     return await openai_embedding(
         texts,
         model="solar-embedding-1-large-query",
         api_key=os.getenv("UPSTAGE_API_KEY"),
-        base_url="https://api.upstage.ai/v1/solar"
+        base_url="https://api.upstage.ai/v1/solar",
     )
+
+
 ## /For Upstage API

+
 def extract_queries(file_path):
-    with open(file_path, 'r') as f:
+    with open(file_path, "r") as f:
         data = f.read()

-    data = data.replace('**', '')
+    data = data.replace("**", "")

-    queries = re.findall(r'- Question \d+: (.+)', data)
+    queries = re.findall(r"- Question \d+: (.+)", data)

     return queries

+
 async def process_query(query_text, rag_instance, query_param):
     try:
         result, context = await rag_instance.aquery(query_text, param=query_param)
@@ -49,6 +55,7 @@ async def process_query(query_text, rag_instance, query_param):
     except Exception as e:
         return None, {"query": query_text, "error": str(e)}

+
 def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     try:
         loop = asyncio.get_event_loop()
@@ -57,15 +64,22 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
         asyncio.set_event_loop(loop)
     return loop

-def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file, error_file):
+
+def run_queries_and_save_to_json(
+    queries, rag_instance, query_param, output_file, error_file
+):
     loop = always_get_an_event_loop()

-    with open(output_file, 'a', encoding='utf-8') as result_file, open(error_file, 'a', encoding='utf-8') as err_file:
+    with open(output_file, "a", encoding="utf-8") as result_file, open(
+        error_file, "a", encoding="utf-8"
+    ) as err_file:
         result_file.write("[\n")
         first_entry = True

         for query_text in tqdm(queries, desc="Processing queries", unit="query"):
-            result, error = loop.run_until_complete(process_query(query_text, rag_instance, query_param))
+            result, error = loop.run_until_complete(
+                process_query(query_text, rag_instance, query_param)
+            )

             if result:
                 if not first_entry:
@@ -78,22 +92,24 @@ def run_queries_and_save_to_json(queries, rag_instance, query_param, output_file

         result_file.write("\n]")

+
 if __name__ == "__main__":
     cls = "mix"
     mode = "hybrid"
     WORKING_DIR = f"../{cls}"

     rag = LightRAG(working_dir=WORKING_DIR)
-    rag = LightRAG(working_dir=WORKING_DIR,
-                   llm_model_func=llm_model_func,
-                   embedding_func=EmbeddingFunc(
-                       embedding_dim=4096,
-                       max_token_size=8192,
-                       func=embedding_func
-                   )
-                   )
+    rag = LightRAG(
+        working_dir=WORKING_DIR,
+        llm_model_func=llm_model_func,
+        embedding_func=EmbeddingFunc(
+            embedding_dim=4096, max_token_size=8192, func=embedding_func
+        ),
+    )
     query_param = QueryParam(mode=mode)

-    base_dir='../datasets/questions'
+    base_dir = "../datasets/questions"
     queries = extract_queries(f"{base_dir}/{cls}_questions.txt")
-    run_queries_and_save_to_json(queries, rag, query_param, f"{base_dir}/result.json", f"{base_dir}/errors.json")
+    run_queries_and_save_to_json(
+        queries, rag, query_param, f"{base_dir}/result.json", f"{base_dir}/errors.json"
+    )
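Both Upstage wrappers in this file are async, so a small smoke test can verify the wiring (API key, `base_url`, and the `embedding_dim=4096` assumption noted in the comment) before launching the full query loop. A hedged sketch, assuming the script is importable from the `reproduce/` directory, `UPSTAGE_API_KEY` is exported, and `llm_model_func` accepts a single positional prompt as in the README examples:

```python
# Smoke-test sketch for the Upstage wrappers above; names mirror the script, and the
# expected (1, 4096) shape follows the embedding_dim assumed in this file's comments.
import asyncio
import os

from Step_3_openai_compatible import embedding_func, llm_model_func


async def main() -> None:
    assert os.getenv("UPSTAGE_API_KEY"), "export UPSTAGE_API_KEY first"
    vectors = await embedding_func(["hello world"])
    print("embedding shape:", vectors.shape)  # expect (1, 4096) if the dim assumption holds
    reply = await llm_model_func("Reply with a single word.")
    print("llm reply:", reply)


if __name__ == "__main__":
    asyncio.run(main())
```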
requirements.txt CHANGED
@@ -1,13 +1,13 @@
+accelerate
 aioboto3
-openai
-tiktoken
-networkx
 graspologic
-nano-vectordb
 hnswlib
-xxhash
+nano-vectordb
+networkx
+ollama
+openai
 tenacity
-transformers
+tiktoken
 torch
-ollama
-accelerate
+transformers
+xxhash