LarFii committed on
Commit
c6de7de
·
1 Parent(s): 4cfd55e
README.md ADDED
@@ -0,0 +1,198 @@
1
+ # LightRAG: Simple and Fast Retrieval-Augmented Generation
2
+ ![LightRAG](https://i-blog.csdnimg.cn/direct/567139f1a36e4564abc63ce5c12b6271.jpeg)
3
+
4
+
5
+
6
+ <a href='https://github.com/HKUDS/LightRAG'><img src='https://img.shields.io/badge/Project-Page-Green'></a>
7
+ <a href='https://arxiv.org/abs/2410.05779'><img src='https://img.shields.io/badge/arXiv-2410.05779-b31b1b'></a>
8
+
9
+ This repository hosts the code of LightRAG. The structure of this code is based on [nano-graphrag](https://github.com/gusye1234/nano-graphrag).
10
+ ![LightRAG overview](https://i-blog.csdnimg.cn/direct/b2aaf634151b4706892693ffb43d9093.png)
11
+ ## Install
12
+
13
+ * Install from source
14
+
15
+ ```bash
16
+ cd LightRAG
17
+ pip install -e .
18
+ ```
19
+ * Install from PyPI
20
+ ```bash
21
+ pip install lightrag-hku
22
+ ```
23
+
24
+ ## Quick Start
25
+
26
+ * Set your OpenAI API key in the environment: `export OPENAI_API_KEY="sk-..."`
27
+ * Download the demo text "A Christmas Carol by Charles Dickens"
28
+ ```bash
29
+ curl https://raw.githubusercontent.com/gusye1234/nano-graphrag/main/tests/mock_data.txt > ./book.txt
30
+ ```
31
+ Use the following Python snippet:
32
+
33
+ ```python
34
+ from lightrag import LightRAG, QueryParam
35
+
36
+ rag = LightRAG(working_dir="./dickens")
37
+
38
+ with open("./book.txt") as f:
39
+ rag.insert(f.read())
40
+
41
+ # Perform naive search
42
+ print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
43
+
44
+ # Perform local search
45
+ print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
46
+
47
+ # Perform global search
48
+ print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
49
+
50
+ # Perform hybrid search (note: the mode keyword is spelled "hybird" in this version)
51
+ print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybird")))
52
+ ```
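`QueryParam` (defined in `lightrag/base.py`) exposes a few more knobs besides `mode`. A minimal sketch of using them is shown below; the values are illustrative, not tuned recommendations:

```python
from lightrag import LightRAG, QueryParam

rag = LightRAG(working_dir="./dickens")

# Retrieve more candidates and ask for a shorter answer format.
param = QueryParam(
    mode="local",
    top_k=100,                         # default is 60
    response_type="Single Paragraph",  # free-form hint passed to the LLM
)
print(rag.query("Who is Scrooge?", param=param))

# Return only the retrieved context (entities, relationships, sources)
# without generating an answer.
context = rag.query(
    "Who is Scrooge?",
    param=QueryParam(mode="local", only_need_context=True),
)
```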
53
+ Batch Insert
54
+ ```python
55
+ rag.insert(["TEXT1", "TEXT2",...])
56
+ ```
57
+ Incremental Insert
58
+
59
+ ```python
60
+ rag = LightRAG(working_dir="./dickens")
61
+
62
+ with open("./newText.txt") as f:
63
+ rag.insert(f.read())
64
+ ```
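`insert` and `query` are synchronous wrappers around the coroutines `ainsert` and `aquery` (see `lightrag/lightrag.py`), so LightRAG can also be driven from async code. A rough sketch:

```python
import asyncio

from lightrag import LightRAG, QueryParam

async def main():
    rag = LightRAG(working_dir="./dickens")
    with open("./book.txt") as f:
        await rag.ainsert(f.read())
    answer = await rag.aquery(
        "What are the top themes in this story?",
        param=QueryParam(mode="global"),
    )
    print(answer)

asyncio.run(main())
```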
65
+ ## Evaluation
66
+ ### Dataset
67
+ The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
68
+
69
+ ### Generate Query
70
+ LightRAG uses the following prompt to generate high-level queries, with the corresponding code located in `examples/generate_query.py`.
71
+ ```python
72
+ Given the following description of a dataset:
73
+
74
+ {description}
75
+
76
+ Please identify 5 potential users who would engage with this dataset. For each user, list 5 tasks they would perform with this dataset. Then, for each (user, task) combination, generate 5 questions that require a high-level understanding of the entire dataset.
77
+
78
+ Output the results in the following structure:
79
+ - User 1: [user description]
80
+ - Task 1: [task description]
81
+ - Question 1:
82
+ - Question 2:
83
+ - Question 3:
84
+ - Question 4:
85
+ - Question 5:
86
+ - Task 2: [task description]
87
+ ...
88
+ - Task 5: [task description]
89
+ - User 2: [user description]
90
+ ...
91
+ - User 5: [user description]
92
+ ...
93
+ ```
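`examples/generate_query.py` holds the actual script; the snippet below is only a sketch of how the prompt might be filled in and sent to an OpenAI model. The dataset description and model name here are placeholder assumptions:

```python
from openai import OpenAI

# Sketch only: see examples/generate_query.py for the actual script.
# The description below is a hypothetical placeholder.
dataset_description = "A corpus of agricultural research articles on crop yield and soil health."

prompt = f"""Given the following description of a dataset:

{dataset_description}

Please identify 5 potential users who would engage with this dataset. For each user,
list 5 tasks they would perform with this dataset. Then, for each (user, task) combination,
generate 5 questions that require a high-level understanding of the entire dataset."""

client = OpenAI()  # reads OPENAI_API_KEY from the environment
completion = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": prompt}],
)
print(completion.choices[0].message.content)
```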
94
+
95
+ ### Batch Eval
96
+ To evaluate the performance of two RAG systems on high-level queries, LightRAG uses the following prompt; the corresponding code is in `examples/batch_eval.py`.
97
+ ```python
98
+ ---Role---
99
+ You are an expert tasked with evaluating two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
100
+ ---Goal---
101
+ You will evaluate two answers to the same question based on three criteria: **Comprehensiveness**, **Diversity**, and **Empowerment**.
102
+
103
+ - **Comprehensiveness**: How much detail does the answer provide to cover all aspects and details of the question?
104
+ - **Diversity**: How varied and rich is the answer in providing different perspectives and insights on the question?
105
+ - **Empowerment**: How well does the answer help the reader understand and make informed judgments about the topic?
106
+
107
+ For each criterion, choose the better answer (either Answer 1 or Answer 2) and explain why. Then, select an overall winner based on these three categories.
108
+
109
+ Here is the question:
110
+ {query}
111
+
112
+ Here are the two answers:
113
+
114
+ **Answer 1:**
115
+ {answer1}
116
+
117
+ **Answer 2:**
118
+ {answer2}
119
+
120
+ Evaluate both answers using the three criteria listed above and provide detailed explanations for each criterion.
121
+
122
+ Output your evaluation in the following JSON format:
123
+
124
+ {{
125
+ "Comprehensiveness": {{
126
+ "Winner": "[Answer 1 or Answer 2]",
127
+ "Explanation": "[Provide explanation here]"
128
+ }},
129
+ "Empowerment": {{
130
+ "Winner": "[Answer 1 or Answer 2]",
131
+ "Explanation": "[Provide explanation here]"
132
+ }},
133
+ "Overall Winner": {{
134
+ "Winner": "[Answer 1 or Answer 2]",
135
+ "Explanation": "[Summarize why this answer is the overall winner based on the three criteria]"
136
+ }}
137
+ }}
138
+ ```
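Note that the doubled braces `{{ }}` in the JSON skeleton above are escapes so the template can be passed through Python's `str.format`, with `query`, `answer1`, and `answer2` as the only slots. The actual evaluation script is `examples/batch_eval.py`; the snippet below is only a sketch, and the model name, file name, and example strings are placeholder assumptions:

```python
from openai import OpenAI

# Hypothetical: the template file holds the evaluation prompt shown above verbatim.
with open("batch_eval_prompt.txt") as f:
    eval_template = f.read()

filled = eval_template.format(
    query="What are the main drivers of crop yield?",  # placeholder question
    answer1="...",  # answer produced by the baseline RAG system
    answer2="...",  # answer produced by LightRAG
)

client = OpenAI()
completion = client.chat.completions.create(
    model="gpt-4o",
    messages=[{"role": "user", "content": filled}],
)
print(completion.choices[0].message.content)  # JSON verdict in the format above
```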
139
+ ### Overall Performance Table
141
+ | | **Agriculture** | | **CS** | | **Legal** | | **Mix** | |
142
+ |----------------------|-------------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|-----------------------|
143
+ | | NaiveRAG | **LightRAG** | NaiveRAG | **LightRAG** | NaiveRAG | **LightRAG** | NaiveRAG | **LightRAG** |
144
+ | **Comprehensiveness** | 32.69% | **67.31%** | 35.44% | **64.56%** | 19.05% | **80.95%** | 36.36% | **63.64%** |
145
+ | **Diversity** | 24.09% | **75.91%** | 35.24% | **64.76%** | 10.98% | **89.02%** | 30.76% | **69.24%** |
146
+ | **Empowerment** | 31.35% | **68.65%** | 35.48% | **64.52%** | 17.59% | **82.41%** | 40.95% | **59.05%** |
147
+ | **Overall** | 33.30% | **66.70%** | 34.76% | **65.24%** | 17.46% | **82.54%** | 37.59% | **62.40%** |
148
+ | | RQ-RAG | **LightRAG** | RQ-RAG | **LightRAG** | RQ-RAG | **LightRAG** | RQ-RAG | **LightRAG** |
149
+ | **Comprehensiveness** | 32.05% | **67.95%** | 39.30% | **60.70%** | 18.57% | **81.43%** | 38.89% | **61.11%** |
150
+ | **Diversity** | 29.44% | **70.56%** | 38.71% | **61.29%** | 15.14% | **84.86%** | 28.50% | **71.50%** |
151
+ | **Empowerment** | 32.51% | **67.49%** | 37.52% | **62.48%** | 17.80% | **82.20%** | 43.96% | **56.04%** |
152
+ | **Overall** | 33.29% | **66.71%** | 39.03% | **60.97%** | 17.80% | **82.20%** | 39.61% | **60.39%** |
153
+ | | HyDE | **LightRAG** | HyDE | **LightRAG** | HyDE | **LightRAG** | HyDE | **LightRAG** |
154
+ | **Comprehensiveness** | 24.39% | **75.61%** | 36.49% | **63.51%** | 27.68% | **72.32%** | 42.17% | **57.83%** |
155
+ | **Diversity** | 24.96% | **75.34%** | 37.41% | **62.59%** | 18.79% | **81.21%** | 30.88% | **69.12%** |
156
+ | **Empowerment** | 24.89% | **75.11%** | 34.99% | **65.01%** | 26.99% | **73.01%** | 45.61% | **54.39%** |
157
+ | **Overall** | 23.17% | **76.83%** | 35.67% | **64.33%** | 27.68% | **72.32%** | 42.72% | **57.28%** |
158
+ | | GraphRAG | **LightRAG** | GraphRAG | **LightRAG** | GraphRAG | **LightRAG** | GraphRAG | **LightRAG** |
159
+ | **Comprehensiveness** | 45.56% | **54.44%** | 45.98% | **54.02%** | 47.13% | **52.87%** | **51.86%** | 48.14% |
160
+ | **Diversity** | 19.65% | **80.35%** | 39.64% | **60.36%** | 25.55% | **74.45%** | 35.87% | **64.13%** |
161
+ | **Empowerment** | 36.69% | **63.31%** | 45.09% | **54.91%** | 42.81% | **57.19%** | **52.94%** | 47.06% |
162
+ | **Overall** | 43.62% | **56.38%** | 45.98% | **54.02%** | 45.70% | **54.30%** | **51.86%** | 48.14% |
163
+
164
+ ## Code Structure
165
+
166
+ ```text
167
+ .
168
+ ├── examples
169
+ │ ├── batch_eval.py
170
+ │ ├── generate_query.py
171
+ │ ├── insert.py
172
+ │ └── query.py
173
+ ├── lightrag
174
+ │ ├── __init__.py
175
+ │ ├── base.py
176
+ │ ├── lightrag.py
177
+ │ ├── llm.py
178
+ │ ├── operate.py
179
+ │ ├── prompt.py
180
+ │ ├── storage.py
181
+ │ ├── utils.py
182
+ ├── LICENSE
183
+ ├── README.md
184
+ ├── requirements.txt
185
+ └── setup.py
186
+ ```
187
+ ## Citation
188
+
189
+ ```
190
+ @article{guo2024lightrag,
191
+ title={LightRAG: Simple and Fast Retrieval-Augmented Generation},
192
+ author={Zirui Guo and Lianghao Xia and Yanhua Yu and Tu Ao and Chao Huang},
193
+ year={2024},
194
+ eprint={2410.05779},
195
+ archivePrefix={arXiv},
196
+ primaryClass={cs.IR}
197
+ }
198
+ ```
lightrag/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .lightrag import LightRAG, QueryParam
2
+
3
+ __version__ = "0.0.2"
4
+ __author__ = "Zirui Guo"
5
+ __url__ = "https://github.com/HKUDS/LightRAG"
lightrag/base.py ADDED
@@ -0,0 +1,116 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import TypedDict, Union, Literal, Generic, TypeVar
3
+
4
+ import numpy as np
5
+
6
+ from .utils import EmbeddingFunc
7
+
8
+ TextChunkSchema = TypedDict(
9
+ "TextChunkSchema",
10
+ {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int},
11
+ )
12
+
13
+ T = TypeVar("T")
14
+
15
+ @dataclass
16
+ class QueryParam:
17
+ mode: Literal["local", "global", "hybird", "naive"] = "global"
18
+ only_need_context: bool = False
19
+ response_type: str = "Multiple Paragraphs"
20
+ top_k: int = 60
21
+ max_token_for_text_unit: int = 4000
22
+ max_token_for_global_context: int = 4000
23
+ max_token_for_local_context: int = 4000
24
+
25
+
26
+ @dataclass
27
+ class StorageNameSpace:
28
+ namespace: str
29
+ global_config: dict
30
+
31
+ async def index_done_callback(self):
32
+ """commit the storage operations after indexing"""
33
+ pass
34
+
35
+ async def query_done_callback(self):
36
+ """commit the storage operations after querying"""
37
+ pass
38
+
39
+ @dataclass
40
+ class BaseVectorStorage(StorageNameSpace):
41
+ embedding_func: EmbeddingFunc
42
+ meta_fields: set = field(default_factory=set)
43
+
44
+ async def query(self, query: str, top_k: int) -> list[dict]:
45
+ raise NotImplementedError
46
+
47
+ async def upsert(self, data: dict[str, dict]):
48
+ """Use 'content' field from value for embedding, use key as id.
49
+ If embedding_func is None, use 'embedding' field from value
50
+ """
51
+ raise NotImplementedError
52
+
53
+ @dataclass
54
+ class BaseKVStorage(Generic[T], StorageNameSpace):
55
+ async def all_keys(self) -> list[str]:
56
+ raise NotImplementedError
57
+
58
+ async def get_by_id(self, id: str) -> Union[T, None]:
59
+ raise NotImplementedError
60
+
61
+ async def get_by_ids(
62
+ self, ids: list[str], fields: Union[set[str], None] = None
63
+ ) -> list[Union[T, None]]:
64
+ raise NotImplementedError
65
+
66
+ async def filter_keys(self, data: list[str]) -> set[str]:
67
+ """Return the keys from data that do not yet exist in storage."""
68
+ raise NotImplementedError
69
+
70
+ async def upsert(self, data: dict[str, T]):
71
+ raise NotImplementedError
72
+
73
+ async def drop(self):
74
+ raise NotImplementedError
75
+
76
+
77
+ @dataclass
78
+ class BaseGraphStorage(StorageNameSpace):
79
+ async def has_node(self, node_id: str) -> bool:
80
+ raise NotImplementedError
81
+
82
+ async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
83
+ raise NotImplementedError
84
+
85
+ async def node_degree(self, node_id: str) -> int:
86
+ raise NotImplementedError
87
+
88
+ async def edge_degree(self, src_id: str, tgt_id: str) -> int:
89
+ raise NotImplementedError
90
+
91
+ async def get_node(self, node_id: str) -> Union[dict, None]:
92
+ raise NotImplementedError
93
+
94
+ async def get_edge(
95
+ self, source_node_id: str, target_node_id: str
96
+ ) -> Union[dict, None]:
97
+ raise NotImplementedError
98
+
99
+ async def get_node_edges(
100
+ self, source_node_id: str
101
+ ) -> Union[list[tuple[str, str]], None]:
102
+ raise NotImplementedError
103
+
104
+ async def upsert_node(self, node_id: str, node_data: dict[str, str]):
105
+ raise NotImplementedError
106
+
107
+ async def upsert_edge(
108
+ self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
109
+ ):
110
+ raise NotImplementedError
111
+
112
+ async def clustering(self, algorithm: str):
113
+ raise NotImplementedError
114
+
115
+ async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
116
+ raise NotImplementedError("Node embedding is not used in lightrag.")
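These abstract classes are the pluggable storage interfaces that `LightRAG` wires together through `key_string_value_json_storage_cls`, `vector_db_storage_cls`, and `graph_storage_cls` (see `lightrag/lightrag.py`). A minimal sketch of a custom KV backend, assuming nothing beyond the interface above (the in-memory class is hypothetical, not part of the package):

```python
from dataclasses import dataclass, field
from typing import Union

from lightrag import LightRAG
from lightrag.base import BaseKVStorage

@dataclass
class InMemoryKVStorage(BaseKVStorage):
    """Toy backend: everything lives in a dict and nothing is persisted."""
    _data: dict = field(default_factory=dict)

    async def all_keys(self) -> list[str]:
        return list(self._data.keys())

    async def get_by_id(self, id: str):
        return self._data.get(id)

    async def get_by_ids(self, ids: list[str], fields: Union[set[str], None] = None):
        return [self._data.get(i) for i in ids]

    async def filter_keys(self, data: list[str]) -> set[str]:
        return {k for k in data if k not in self._data}

    async def upsert(self, data: dict):
        self._data.update(data)

    async def drop(self):
        self._data.clear()

rag = LightRAG(
    working_dir="./dickens",
    key_string_value_json_storage_cls=InMemoryKVStorage,
)
```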
lightrag/lightrag.py ADDED
@@ -0,0 +1,300 @@
1
+ import asyncio
2
+ import os
3
+ from dataclasses import asdict, dataclass, field
4
+ from datetime import datetime
5
+ from functools import partial
6
+ from typing import Type, cast
7
+
8
+ from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding
9
+ from .operate import (
10
+ chunking_by_token_size,
11
+ extract_entities,
12
+ local_query,
13
+ global_query,
14
+ hybird_query,
15
+ naive_query,
16
+ )
17
+
18
+ from .storage import (
19
+ JsonKVStorage,
20
+ NanoVectorDBStorage,
21
+ NetworkXStorage,
22
+ )
23
+ from .utils import (
24
+ EmbeddingFunc,
25
+ compute_mdhash_id,
26
+ limit_async_func_call,
27
+ convert_response_to_json,
28
+ logger,
29
+ set_logger,
30
+ )
31
+ from .base import (
32
+ BaseGraphStorage,
33
+ BaseKVStorage,
34
+ BaseVectorStorage,
35
+ StorageNameSpace,
36
+ QueryParam,
37
+ )
38
+
39
+ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
40
+ try:
41
+ # If there is already an event loop, use it.
42
+ loop = asyncio.get_event_loop()
43
+ except RuntimeError:
44
+ # If in a sub-thread, create a new event loop.
45
+ logger.info("Creating a new event loop in a sub-thread.")
46
+ loop = asyncio.new_event_loop()
47
+ asyncio.set_event_loop(loop)
48
+ return loop
49
+
50
+ @dataclass
51
+ class LightRAG:
52
+ working_dir: str = field(
53
+ default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
54
+ )
55
+
56
+ # text chunking
57
+ chunk_token_size: int = 1200
58
+ chunk_overlap_token_size: int = 100
59
+ tiktoken_model_name: str = "gpt-4o-mini"
60
+
61
+ # entity extraction
62
+ entity_extract_max_gleaning: int = 1
63
+ entity_summary_to_max_tokens: int = 500
64
+
65
+ # node embedding
66
+ node_embedding_algorithm: str = "node2vec"
67
+ node2vec_params: dict = field(
68
+ default_factory=lambda: {
69
+ "dimensions": 1536,
70
+ "num_walks": 10,
71
+ "walk_length": 40,
73
+ "window_size": 2,
74
+ "iterations": 3,
75
+ "random_seed": 3,
76
+ }
77
+ )
78
+
79
+ # text embedding
80
+ embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
81
+ embedding_batch_num: int = 32
82
+ embedding_func_max_async: int = 16
83
+
84
+ # LLM
85
+ llm_model_func: callable = gpt_4o_mini_complete
86
+ llm_model_max_token_size: int = 32768
87
+ llm_model_max_async: int = 16
88
+
89
+ # storage
90
+ key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
91
+ vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
92
+ vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
93
+ graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
94
+ enable_llm_cache: bool = True
95
+
96
+ # extension
97
+ addon_params: dict = field(default_factory=dict)
98
+ convert_response_to_json_func: callable = convert_response_to_json
99
+
100
+ def __post_init__(self):
101
+ log_file = os.path.join(self.working_dir, "lightrag.log")
102
+ set_logger(log_file)
103
+ logger.info(f"Logger initialized for working directory: {self.working_dir}")
104
+
105
+ _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
106
+ logger.debug(f"LightRAG init with param:\n {_print_config}\n")
107
+
108
+ if not os.path.exists(self.working_dir):
109
+ logger.info(f"Creating working directory {self.working_dir}")
110
+ os.makedirs(self.working_dir)
111
+
112
+ self.full_docs = self.key_string_value_json_storage_cls(
113
+ namespace="full_docs", global_config=asdict(self)
114
+ )
115
+
116
+ self.text_chunks = self.key_string_value_json_storage_cls(
117
+ namespace="text_chunks", global_config=asdict(self)
118
+ )
119
+
120
+ self.llm_response_cache = (
121
+ self.key_string_value_json_storage_cls(
122
+ namespace="llm_response_cache", global_config=asdict(self)
123
+ )
124
+ if self.enable_llm_cache
125
+ else None
126
+ )
127
+ self.chunk_entity_relation_graph = self.graph_storage_cls(
128
+ namespace="chunk_entity_relation", global_config=asdict(self)
129
+ )
130
+ self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
131
+ self.embedding_func
132
+ )
133
+ self.entities_vdb = (
134
+ self.vector_db_storage_cls(
135
+ namespace="entities",
136
+ global_config=asdict(self),
137
+ embedding_func=self.embedding_func,
138
+ meta_fields={"entity_name"}
139
+ )
140
+ )
141
+ self.relationships_vdb = (
142
+ self.vector_db_storage_cls(
143
+ namespace="relationships",
144
+ global_config=asdict(self),
145
+ embedding_func=self.embedding_func,
146
+ meta_fields={"src_id", "tgt_id"}
147
+ )
148
+ )
149
+ self.chunks_vdb = (
150
+ self.vector_db_storage_cls(
151
+ namespace="chunks",
152
+ global_config=asdict(self),
153
+ embedding_func=self.embedding_func,
154
+ )
155
+ )
156
+
157
+ self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
158
+ partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
159
+ )
160
+
161
+ def insert(self, string_or_strings):
162
+ loop = always_get_an_event_loop()
163
+ return loop.run_until_complete(self.ainsert(string_or_strings))
164
+
165
+ async def ainsert(self, string_or_strings):
166
+ try:
167
+ if isinstance(string_or_strings, str):
168
+ string_or_strings = [string_or_strings]
169
+
170
+ new_docs = {
171
+ compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
172
+ for c in string_or_strings
173
+ }
174
+ _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
175
+ new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
176
+ if not len(new_docs):
177
+ logger.warning("All docs are already in storage")
178
+ return
179
+ logger.info(f"[New Docs] inserting {len(new_docs)} docs")
180
+
181
+ inserting_chunks = {}
182
+ for doc_key, doc in new_docs.items():
183
+ chunks = {
184
+ compute_mdhash_id(dp["content"], prefix="chunk-"): {
185
+ **dp,
186
+ "full_doc_id": doc_key,
187
+ }
188
+ for dp in chunking_by_token_size(
189
+ doc["content"],
190
+ overlap_token_size=self.chunk_overlap_token_size,
191
+ max_token_size=self.chunk_token_size,
192
+ tiktoken_model=self.tiktoken_model_name,
193
+ )
194
+ }
195
+ inserting_chunks.update(chunks)
196
+ _add_chunk_keys = await self.text_chunks.filter_keys(
197
+ list(inserting_chunks.keys())
198
+ )
199
+ inserting_chunks = {
200
+ k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
201
+ }
202
+ if not len(inserting_chunks):
203
+ logger.warning("All chunks are already in storage")
204
+ return
205
+ logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
206
+
207
+ await self.chunks_vdb.upsert(inserting_chunks)
208
+
209
+ logger.info("[Entity Extraction]...")
210
+ maybe_new_kg = await extract_entities(
211
+ inserting_chunks,
212
+ knwoledge_graph_inst=self.chunk_entity_relation_graph,
213
+ entity_vdb=self.entities_vdb,
214
+ relationships_vdb=self.relationships_vdb,
215
+ global_config=asdict(self),
216
+ )
217
+ if maybe_new_kg is None:
218
+ logger.warning("No new entities and relationships found")
219
+ return
220
+ self.chunk_entity_relation_graph = maybe_new_kg
221
+
222
+ await self.full_docs.upsert(new_docs)
223
+ await self.text_chunks.upsert(inserting_chunks)
224
+ finally:
225
+ await self._insert_done()
226
+
227
+ async def _insert_done(self):
228
+ tasks = []
229
+ for storage_inst in [
230
+ self.full_docs,
231
+ self.text_chunks,
232
+ self.llm_response_cache,
233
+ self.entities_vdb,
234
+ self.relationships_vdb,
235
+ self.chunks_vdb,
236
+ self.chunk_entity_relation_graph,
237
+ ]:
238
+ if storage_inst is None:
239
+ continue
240
+ tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
241
+ await asyncio.gather(*tasks)
242
+
243
+ def query(self, query: str, param: QueryParam = QueryParam()):
244
+ loop = always_get_an_event_loop()
245
+ return loop.run_until_complete(self.aquery(query, param))
246
+
247
+ async def aquery(self, query: str, param: QueryParam = QueryParam()):
248
+ if param.mode == "local":
249
+ response = await local_query(
250
+ query,
251
+ self.chunk_entity_relation_graph,
252
+ self.entities_vdb,
253
+ self.relationships_vdb,
254
+ self.text_chunks,
255
+ param,
256
+ asdict(self),
257
+ )
258
+ elif param.mode == "global":
259
+ response = await global_query(
260
+ query,
261
+ self.chunk_entity_relation_graph,
262
+ self.entities_vdb,
263
+ self.relationships_vdb,
264
+ self.text_chunks,
265
+ param,
266
+ asdict(self),
267
+ )
268
+ elif param.mode == "hybird":
269
+ response = await hybird_query(
270
+ query,
271
+ self.chunk_entity_relation_graph,
272
+ self.entities_vdb,
273
+ self.relationships_vdb,
274
+ self.text_chunks,
275
+ param,
276
+ asdict(self),
277
+ )
278
+ elif param.mode == "naive":
279
+ response = await naive_query(
280
+ query,
281
+ self.chunks_vdb,
282
+ self.text_chunks,
283
+ param,
284
+ asdict(self),
285
+ )
286
+ else:
287
+ raise ValueError(f"Unknown mode {param.mode}")
288
+ await self._query_done()
289
+ return response
290
+
291
+
292
+ async def _query_done(self):
293
+ tasks = []
294
+ for storage_inst in [self.llm_response_cache]:
295
+ if storage_inst is None:
296
+ continue
297
+ tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
298
+ await asyncio.gather(*tasks)
299
+
300
+
lightrag/llm.py ADDED
@@ -0,0 +1,88 @@
1
+ import os
2
+ import numpy as np
3
+ from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
4
+ from tenacity import (
5
+ retry,
6
+ stop_after_attempt,
7
+ wait_exponential,
8
+ retry_if_exception_type,
9
+ )
10
+
11
+ from .base import BaseKVStorage
12
+ from .utils import compute_args_hash, wrap_embedding_func_with_attrs
13
+
14
+ @retry(
15
+ stop=stop_after_attempt(3),
16
+ wait=wait_exponential(multiplier=1, min=4, max=10),
17
+ retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
18
+ )
19
+ async def openai_complete_if_cache(
20
+ model, prompt, system_prompt=None, history_messages=[], **kwargs
21
+ ) -> str:
22
+ openai_async_client = AsyncOpenAI()
23
+ hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
24
+ messages = []
25
+ if system_prompt:
26
+ messages.append({"role": "system", "content": system_prompt})
27
+ messages.extend(history_messages)
28
+ messages.append({"role": "user", "content": prompt})
29
+ if hashing_kv is not None:
30
+ args_hash = compute_args_hash(model, messages)
31
+ if_cache_return = await hashing_kv.get_by_id(args_hash)
32
+ if if_cache_return is not None:
33
+ return if_cache_return["return"]
34
+
35
+ response = await openai_async_client.chat.completions.create(
36
+ model=model, messages=messages, **kwargs
37
+ )
38
+
39
+ if hashing_kv is not None:
40
+ await hashing_kv.upsert(
41
+ {args_hash: {"return": response.choices[0].message.content, "model": model}}
42
+ )
43
+ return response.choices[0].message.content
44
+
45
+ async def gpt_4o_complete(
46
+ prompt, system_prompt=None, history_messages=[], **kwargs
47
+ ) -> str:
48
+ return await openai_complete_if_cache(
49
+ "gpt-4o",
50
+ prompt,
51
+ system_prompt=system_prompt,
52
+ history_messages=history_messages,
53
+ **kwargs,
54
+ )
55
+
56
+
57
+ async def gpt_4o_mini_complete(
58
+ prompt, system_prompt=None, history_messages=[], **kwargs
59
+ ) -> str:
60
+ return await openai_complete_if_cache(
61
+ "gpt-4o-mini",
62
+ prompt,
63
+ system_prompt=system_prompt,
64
+ history_messages=history_messages,
65
+ **kwargs,
66
+ )
67
+
68
+ @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
69
+ @retry(
70
+ stop=stop_after_attempt(3),
71
+ wait=wait_exponential(multiplier=1, min=4, max=10),
72
+ retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
73
+ )
74
+ async def openai_embedding(texts: list[str]) -> np.ndarray:
75
+ openai_async_client = AsyncOpenAI()
76
+ response = await openai_async_client.embeddings.create(
77
+ model="text-embedding-3-small", input=texts, encoding_format="float"
78
+ )
79
+ return np.array([dp.embedding for dp in response.data])
80
+
81
+ if __name__ == "__main__":
82
+ import asyncio
83
+
84
+ async def main():
85
+ result = await gpt_4o_mini_complete('How are you?')
86
+ print(result)
87
+
88
+ asyncio.run(main())
lightrag/operate.py ADDED
@@ -0,0 +1,944 @@
1
+ import asyncio
2
+ import json
3
+ import re
4
+ from typing import Union
5
+ from collections import Counter, defaultdict
6
+
7
+ from .utils import (
8
+ logger,
9
+ clean_str,
10
+ compute_mdhash_id,
11
+ decode_tokens_by_tiktoken,
12
+ encode_string_by_tiktoken,
13
+ is_float_regex,
14
+ list_of_list_to_csv,
15
+ pack_user_ass_to_openai_messages,
16
+ split_string_by_multi_markers,
17
+ truncate_list_by_token_size,
18
+ )
19
+ from .base import (
20
+ BaseGraphStorage,
21
+ BaseKVStorage,
22
+ BaseVectorStorage,
23
+ TextChunkSchema,
24
+ QueryParam,
25
+ )
26
+ from .prompt import GRAPH_FIELD_SEP, PROMPTS
27
+
28
+ def chunking_by_token_size(
29
+ content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
30
+ ):
31
+ tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
32
+ results = []
33
+ for index, start in enumerate(
34
+ range(0, len(tokens), max_token_size - overlap_token_size)
35
+ ):
36
+ chunk_content = decode_tokens_by_tiktoken(
37
+ tokens[start : start + max_token_size], model_name=tiktoken_model
38
+ )
39
+ results.append(
40
+ {
41
+ "tokens": min(max_token_size, len(tokens) - start),
42
+ "content": chunk_content.strip(),
43
+ "chunk_order_index": index,
44
+ }
45
+ )
46
+ return results
47
+
48
+ async def _handle_entity_relation_summary(
49
+ entity_or_relation_name: str,
50
+ description: str,
51
+ global_config: dict,
52
+ ) -> str:
53
+ use_llm_func: callable = global_config["llm_model_func"]
54
+ llm_max_tokens = global_config["llm_model_max_token_size"]
55
+ tiktoken_model_name = global_config["tiktoken_model_name"]
56
+ summary_max_tokens = global_config["entity_summary_to_max_tokens"]
57
+
58
+ tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name)
59
+ if len(tokens) < summary_max_tokens: # No need for summary
60
+ return description
61
+ prompt_template = PROMPTS["summarize_entity_descriptions"]
62
+ use_description = decode_tokens_by_tiktoken(
63
+ tokens[:llm_max_tokens], model_name=tiktoken_model_name
64
+ )
65
+ context_base = dict(
66
+ entity_name=entity_or_relation_name,
67
+ description_list=use_description.split(GRAPH_FIELD_SEP),
68
+ )
69
+ use_prompt = prompt_template.format(**context_base)
70
+ logger.debug(f"Trigger summary: {entity_or_relation_name}")
71
+ summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens)
72
+ return summary
73
+
74
+
75
+ async def _handle_single_entity_extraction(
76
+ record_attributes: list[str],
77
+ chunk_key: str,
78
+ ):
79
+ if record_attributes[0] != '"entity"' or len(record_attributes) < 4:
80
+ return None
81
+ # add this record as a node in the G
82
+ entity_name = clean_str(record_attributes[1].upper())
83
+ if not entity_name.strip():
84
+ return None
85
+ entity_type = clean_str(record_attributes[2].upper())
86
+ entity_description = clean_str(record_attributes[3])
87
+ entity_source_id = chunk_key
88
+ return dict(
89
+ entity_name=entity_name,
90
+ entity_type=entity_type,
91
+ description=entity_description,
92
+ source_id=entity_source_id,
93
+ )
94
+
95
+
96
+ async def _handle_single_relationship_extraction(
97
+ record_attributes: list[str],
98
+ chunk_key: str,
99
+ ):
100
+ if record_attributes[0] != '"relationship"' or len(record_attributes) < 5:
101
+ return None
102
+ # add this record as edge
103
+ source = clean_str(record_attributes[1].upper())
104
+ target = clean_str(record_attributes[2].upper())
105
+ edge_description = clean_str(record_attributes[3])
106
+
107
+ edge_keywords = clean_str(record_attributes[4])
108
+ edge_source_id = chunk_key
109
+ weight = (
110
+ float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
111
+ )
112
+ return dict(
113
+ src_id=source,
114
+ tgt_id=target,
115
+ weight=weight,
116
+ description=edge_description,
117
+ keywords=edge_keywords,
118
+ source_id=edge_source_id,
119
+ )
120
+
121
+
122
+ async def _merge_nodes_then_upsert(
123
+ entity_name: str,
124
+ nodes_data: list[dict],
125
+ knwoledge_graph_inst: BaseGraphStorage,
126
+ global_config: dict,
127
+ ):
128
+ already_entitiy_types = []
129
+ already_source_ids = []
130
+ already_description = []
131
+
132
+ already_node = await knwoledge_graph_inst.get_node(entity_name)
133
+ if already_node is not None:
134
+ already_entitiy_types.append(already_node["entity_type"])
135
+ already_source_ids.extend(
136
+ split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP])
137
+ )
138
+ already_description.append(already_node["description"])
139
+
140
+ entity_type = sorted(
141
+ Counter(
142
+ [dp["entity_type"] for dp in nodes_data] + already_entitiy_types
143
+ ).items(),
144
+ key=lambda x: x[1],
145
+ reverse=True,
146
+ )[0][0]
147
+ description = GRAPH_FIELD_SEP.join(
148
+ sorted(set([dp["description"] for dp in nodes_data] + already_description))
149
+ )
150
+ source_id = GRAPH_FIELD_SEP.join(
151
+ set([dp["source_id"] for dp in nodes_data] + already_source_ids)
152
+ )
153
+ description = await _handle_entity_relation_summary(
154
+ entity_name, description, global_config
155
+ )
156
+ node_data = dict(
157
+ entity_type=entity_type,
158
+ description=description,
159
+ source_id=source_id,
160
+ )
161
+ await knwoledge_graph_inst.upsert_node(
162
+ entity_name,
163
+ node_data=node_data,
164
+ )
165
+ node_data["entity_name"] = entity_name
166
+ return node_data
167
+
168
+
169
+ async def _merge_edges_then_upsert(
170
+ src_id: str,
171
+ tgt_id: str,
172
+ edges_data: list[dict],
173
+ knwoledge_graph_inst: BaseGraphStorage,
174
+ global_config: dict,
175
+ ):
176
+ already_weights = []
177
+ already_source_ids = []
178
+ already_description = []
179
+ already_keywords = []
180
+
181
+ if await knwoledge_graph_inst.has_edge(src_id, tgt_id):
182
+ already_edge = await knwoledge_graph_inst.get_edge(src_id, tgt_id)
183
+ already_weights.append(already_edge["weight"])
184
+ already_source_ids.extend(
185
+ split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
186
+ )
187
+ already_description.append(already_edge["description"])
188
+ already_keywords.extend(
189
+ split_string_by_multi_markers(already_edge["keywords"], [GRAPH_FIELD_SEP])
190
+ )
191
+
192
+ weight = sum([dp["weight"] for dp in edges_data] + already_weights)
193
+ description = GRAPH_FIELD_SEP.join(
194
+ sorted(set([dp["description"] for dp in edges_data] + already_description))
195
+ )
196
+ keywords = GRAPH_FIELD_SEP.join(
197
+ sorted(set([dp["keywords"] for dp in edges_data] + already_keywords))
198
+ )
199
+ source_id = GRAPH_FIELD_SEP.join(
200
+ set([dp["source_id"] for dp in edges_data] + already_source_ids)
201
+ )
202
+ for need_insert_id in [src_id, tgt_id]:
203
+ if not (await knwoledge_graph_inst.has_node(need_insert_id)):
204
+ await knwoledge_graph_inst.upsert_node(
205
+ need_insert_id,
206
+ node_data={
207
+ "source_id": source_id,
208
+ "description": description,
209
+ "entity_type": '"UNKNOWN"',
210
+ },
211
+ )
212
+ description = await _handle_entity_relation_summary(
213
+ (src_id, tgt_id), description, global_config
214
+ )
215
+ await knwoledge_graph_inst.upsert_edge(
216
+ src_id,
217
+ tgt_id,
218
+ edge_data=dict(
219
+ weight=weight,
220
+ description=description,
221
+ keywords=keywords,
222
+ source_id=source_id,
223
+ ),
224
+ )
225
+
226
+ edge_data = dict(
227
+ src_id=src_id,
228
+ tgt_id=tgt_id,
229
+ description=description,
230
+ keywords=keywords,
231
+ )
232
+
233
+ return edge_data
234
+
235
+ async def extract_entities(
236
+ chunks: dict[str, TextChunkSchema],
237
+ knwoledge_graph_inst: BaseGraphStorage,
238
+ entity_vdb: BaseVectorStorage,
239
+ relationships_vdb: BaseVectorStorage,
240
+ global_config: dict,
241
+ ) -> Union[BaseGraphStorage, None]:
242
+ use_llm_func: callable = global_config["llm_model_func"]
243
+ entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
244
+
245
+ ordered_chunks = list(chunks.items())
246
+
247
+ entity_extract_prompt = PROMPTS["entity_extraction"]
248
+ context_base = dict(
249
+ tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
250
+ record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
251
+ completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
252
+ entity_types=",".join(PROMPTS["DEFAULT_ENTITY_TYPES"]),
253
+ )
254
+ continue_prompt = PROMPTS["entiti_continue_extraction"]
255
+ if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]
256
+
257
+ already_processed = 0
258
+ already_entities = 0
259
+ already_relations = 0
260
+
261
+ async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
262
+ nonlocal already_processed, already_entities, already_relations
263
+ chunk_key = chunk_key_dp[0]
264
+ chunk_dp = chunk_key_dp[1]
265
+ content = chunk_dp["content"]
266
+ hint_prompt = entity_extract_prompt.format(**context_base, input_text=content)
267
+ final_result = await use_llm_func(hint_prompt)
268
+
269
+ history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
270
+ for now_glean_index in range(entity_extract_max_gleaning):
271
+ glean_result = await use_llm_func(continue_prompt, history_messages=history)
272
+
273
+ history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
274
+ final_result += glean_result
275
+ if now_glean_index == entity_extract_max_gleaning - 1:
276
+ break
277
+
278
+ if_loop_result: str = await use_llm_func(
279
+ if_loop_prompt, history_messages=history
280
+ )
281
+ if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
282
+ if if_loop_result != "yes":
283
+ break
284
+
285
+ records = split_string_by_multi_markers(
286
+ final_result,
287
+ [context_base["record_delimiter"], context_base["completion_delimiter"]],
288
+ )
289
+
290
+ maybe_nodes = defaultdict(list)
291
+ maybe_edges = defaultdict(list)
292
+ for record in records:
293
+ record = re.search(r"\((.*)\)", record)
294
+ if record is None:
295
+ continue
296
+ record = record.group(1)
297
+ record_attributes = split_string_by_multi_markers(
298
+ record, [context_base["tuple_delimiter"]]
299
+ )
300
+ if_entities = await _handle_single_entity_extraction(
301
+ record_attributes, chunk_key
302
+ )
303
+ if if_entities is not None:
304
+ maybe_nodes[if_entities["entity_name"]].append(if_entities)
305
+ continue
306
+
307
+ if_relation = await _handle_single_relationship_extraction(
308
+ record_attributes, chunk_key
309
+ )
310
+ if if_relation is not None:
311
+ maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
312
+ if_relation
313
+ )
314
+ already_processed += 1
315
+ already_entities += len(maybe_nodes)
316
+ already_relations += len(maybe_edges)
317
+ now_ticks = PROMPTS["process_tickers"][
318
+ already_processed % len(PROMPTS["process_tickers"])
319
+ ]
320
+ print(
321
+ f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
322
+ end="",
323
+ flush=True,
324
+ )
325
+ return dict(maybe_nodes), dict(maybe_edges)
326
+
327
+ # use_llm_func is wrapped in an asyncio.Semaphore, limiting the number of concurrent calls
328
+ results = await asyncio.gather(
329
+ *[_process_single_content(c) for c in ordered_chunks]
330
+ )
331
+ print() # clear the progress bar
332
+ maybe_nodes = defaultdict(list)
333
+ maybe_edges = defaultdict(list)
334
+ for m_nodes, m_edges in results:
335
+ for k, v in m_nodes.items():
336
+ maybe_nodes[k].extend(v)
337
+ for k, v in m_edges.items():
338
+ maybe_edges[tuple(sorted(k))].extend(v)
339
+ all_entities_data = await asyncio.gather(
340
+ *[
341
+ _merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config)
342
+ for k, v in maybe_nodes.items()
343
+ ]
344
+ )
345
+ all_relationships_data = await asyncio.gather(
346
+ *[
347
+ _merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config)
348
+ for k, v in maybe_edges.items()
349
+ ]
350
+ )
351
+ if not len(all_entities_data):
352
+ logger.warning("Didn't extract any entities, maybe your LLM is not working")
353
+ return None
354
+ if not len(all_relationships_data):
355
+ logger.warning("Didn't extract any relationships, maybe your LLM is not working")
356
+ return None
357
+
358
+ if entity_vdb is not None:
359
+ data_for_vdb = {
360
+ compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
361
+ "content": dp["entity_name"] + dp["description"],
362
+ "entity_name": dp["entity_name"],
363
+ }
364
+ for dp in all_entities_data
365
+ }
366
+ await entity_vdb.upsert(data_for_vdb)
367
+
368
+ if relationships_vdb is not None:
369
+ data_for_vdb = {
370
+ compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
371
+ "src_id": dp["src_id"],
372
+ "tgt_id": dp["tgt_id"],
373
+ "content": dp["keywords"] + dp["src_id"] + dp["tgt_id"] + dp["description"],
374
+ }
375
+ for dp in all_relationships_data
376
+ }
377
+ await relationships_vdb.upsert(data_for_vdb)
378
+
379
+ return knwoledge_graph_inst
380
+
381
+ async def local_query(
382
+ query,
383
+ knowledge_graph_inst: BaseGraphStorage,
384
+ entities_vdb: BaseVectorStorage,
385
+ relationships_vdb: BaseVectorStorage,
386
+ text_chunks_db: BaseKVStorage[TextChunkSchema],
387
+ query_param: QueryParam,
388
+ global_config: dict,
389
+ ) -> str:
390
+ use_model_func = global_config["llm_model_func"]
391
+
392
+ kw_prompt_temp = PROMPTS["keywords_extraction"]
393
+ kw_prompt = kw_prompt_temp.format(query=query)
394
+ result = await use_model_func(kw_prompt)
395
+
396
+ try:
397
+ keywords_data = json.loads(result)
398
+ keywords = keywords_data.get("low_level_keywords", [])
399
+ keywords = ', '.join(keywords)
400
+ except json.JSONDecodeError as e:
401
+ # Handle parsing error
402
+ print(f"JSON parsing error: {e}")
403
+ return PROMPTS["fail_response"]
404
+
405
+ context = await _build_local_query_context(
406
+ keywords,
407
+ knowledge_graph_inst,
408
+ entities_vdb,
409
+ text_chunks_db,
410
+ query_param,
411
+ )
412
+ if query_param.only_need_context:
413
+ return context
414
+ if context is None:
415
+ return PROMPTS["fail_response"]
416
+ sys_prompt_temp = PROMPTS["rag_response"]
417
+ sys_prompt = sys_prompt_temp.format(
418
+ context_data=context, response_type=query_param.response_type
419
+ )
420
+ response = await use_model_func(
421
+ query,
422
+ system_prompt=sys_prompt,
423
+ )
424
+ return response
425
+
426
+ async def _build_local_query_context(
427
+ query,
428
+ knowledge_graph_inst: BaseGraphStorage,
429
+ entities_vdb: BaseVectorStorage,
430
+ text_chunks_db: BaseKVStorage[TextChunkSchema],
431
+ query_param: QueryParam,
432
+ ):
433
+ results = await entities_vdb.query(query, top_k=query_param.top_k)
434
+ if not len(results):
435
+ return None
436
+ node_datas = await asyncio.gather(
437
+ *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
438
+ )
439
+ if not all([n is not None for n in node_datas]):
440
+ logger.warning("Some nodes are missing, maybe the storage is damaged")
441
+ node_degrees = await asyncio.gather(
442
+ *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
443
+ )
444
+ node_datas = [
445
+ {**n, "entity_name": k["entity_name"], "rank": d}
446
+ for k, n, d in zip(results, node_datas, node_degrees)
447
+ if n is not None
448
+ ]
449
+ use_text_units = await _find_most_related_text_unit_from_entities(
450
+ node_datas, query_param, text_chunks_db, knowledge_graph_inst
451
+ )
452
+ use_relations = await _find_most_related_edges_from_entities(
453
+ node_datas, query_param, knowledge_graph_inst
454
+ )
455
+ logger.info(
456
+ f"Local query uses {len(node_datas)} entities, {len(use_relations)} relations, {len(use_text_units)} text units"
457
+ )
458
+ entites_section_list = [["id", "entity", "type", "description", "rank"]]
459
+ for i, n in enumerate(node_datas):
460
+ entites_section_list.append(
461
+ [
462
+ i,
463
+ n["entity_name"],
464
+ n.get("entity_type", "UNKNOWN"),
465
+ n.get("description", "UNKNOWN"),
466
+ n["rank"],
467
+ ]
468
+ )
469
+ entities_context = list_of_list_to_csv(entites_section_list)
470
+
471
+ relations_section_list = [
472
+ ["id", "source", "target", "description", "keywords", "weight", "rank"]
473
+ ]
474
+ for i, e in enumerate(use_relations):
475
+ relations_section_list.append(
476
+ [
477
+ i,
478
+ e["src_tgt"][0],
479
+ e["src_tgt"][1],
480
+ e["description"],
481
+ e["keywords"],
482
+ e["weight"],
483
+ e["rank"],
484
+ ]
485
+ )
486
+ relations_context = list_of_list_to_csv(relations_section_list)
487
+
488
+ text_units_section_list = [["id", "content"]]
489
+ for i, t in enumerate(use_text_units):
490
+ text_units_section_list.append([i, t["content"]])
491
+ text_units_context = list_of_list_to_csv(text_units_section_list)
492
+ return f"""
493
+ -----Entities-----
494
+ ```csv
495
+ {entities_context}
496
+ ```
497
+ -----Relationships-----
498
+ ```csv
499
+ {relations_context}
500
+ ```
501
+ -----Sources-----
502
+ ```csv
503
+ {text_units_context}
504
+ ```
505
+ """
506
+
507
+ async def _find_most_related_text_unit_from_entities(
508
+ node_datas: list[dict],
509
+ query_param: QueryParam,
510
+ text_chunks_db: BaseKVStorage[TextChunkSchema],
511
+ knowledge_graph_inst: BaseGraphStorage,
512
+ ):
513
+ text_units = [
514
+ split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
515
+ for dp in node_datas
516
+ ]
517
+ edges = await asyncio.gather(
518
+ *[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
519
+ )
520
+ all_one_hop_nodes = set()
521
+ for this_edges in edges:
522
+ if not this_edges:
523
+ continue
524
+ all_one_hop_nodes.update([e[1] for e in this_edges])
525
+ all_one_hop_nodes = list(all_one_hop_nodes)
526
+ all_one_hop_nodes_data = await asyncio.gather(
527
+ *[knowledge_graph_inst.get_node(e) for e in all_one_hop_nodes]
528
+ )
529
+ all_one_hop_text_units_lookup = {
530
+ k: set(split_string_by_multi_markers(v["source_id"], [GRAPH_FIELD_SEP]))
531
+ for k, v in zip(all_one_hop_nodes, all_one_hop_nodes_data)
532
+ if v is not None
533
+ }
534
+ all_text_units_lookup = {}
535
+ for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
536
+ for c_id in this_text_units:
537
+ if c_id in all_text_units_lookup:
538
+ continue
539
+ relation_counts = 0
540
+ for e in this_edges:
541
+ if (
542
+ e[1] in all_one_hop_text_units_lookup
543
+ and c_id in all_one_hop_text_units_lookup[e[1]]
544
+ ):
545
+ relation_counts += 1
546
+ all_text_units_lookup[c_id] = {
547
+ "data": await text_chunks_db.get_by_id(c_id),
548
+ "order": index,
549
+ "relation_counts": relation_counts,
550
+ }
551
+ if any([v is None for v in all_text_units_lookup.values()]):
552
+ logger.warning("Text chunks are missing, maybe the storage is damaged")
553
+ all_text_units = [
554
+ {"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
555
+ ]
556
+ all_text_units = sorted(
557
+ all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
558
+ )
559
+ all_text_units = truncate_list_by_token_size(
560
+ all_text_units,
561
+ key=lambda x: x["data"]["content"],
562
+ max_token_size=query_param.max_token_for_text_unit,
563
+ )
564
+ all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
565
+ return all_text_units
566
+
567
+ async def _find_most_related_edges_from_entities(
568
+ node_datas: list[dict],
569
+ query_param: QueryParam,
570
+ knowledge_graph_inst: BaseGraphStorage,
571
+ ):
572
+ all_related_edges = await asyncio.gather(
573
+ *[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
574
+ )
575
+ all_edges = set()
576
+ for this_edges in all_related_edges:
577
+ all_edges.update([tuple(sorted(e)) for e in this_edges])
578
+ all_edges = list(all_edges)
579
+ all_edges_pack = await asyncio.gather(
580
+ *[knowledge_graph_inst.get_edge(e[0], e[1]) for e in all_edges]
581
+ )
582
+ all_edges_degree = await asyncio.gather(
583
+ *[knowledge_graph_inst.edge_degree(e[0], e[1]) for e in all_edges]
584
+ )
585
+ all_edges_data = [
586
+ {"src_tgt": k, "rank": d, **v}
587
+ for k, v, d in zip(all_edges, all_edges_pack, all_edges_degree)
588
+ if v is not None
589
+ ]
590
+ all_edges_data = sorted(
591
+ all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
592
+ )
593
+ all_edges_data = truncate_list_by_token_size(
594
+ all_edges_data,
595
+ key=lambda x: x["description"],
596
+ max_token_size=query_param.max_token_for_global_context,
597
+ )
598
+ return all_edges_data
599
+
600
+ async def global_query(
601
+ query,
602
+ knowledge_graph_inst: BaseGraphStorage,
603
+ entities_vdb: BaseVectorStorage,
604
+ relationships_vdb: BaseVectorStorage,
605
+ text_chunks_db: BaseKVStorage[TextChunkSchema],
606
+ query_param: QueryParam,
607
+ global_config: dict,
608
+ ) -> str:
609
+ use_model_func = global_config["llm_model_func"]
610
+
611
+ kw_prompt_temp = PROMPTS["keywords_extraction"]
612
+ kw_prompt = kw_prompt_temp.format(query=query)
613
+ result = await use_model_func(kw_prompt)
614
+
615
+ try:
616
+ keywords_data = json.loads(result)
617
+ keywords = keywords_data.get("high_level_keywords", [])
618
+ keywords = ', '.join(keywords)
619
+ except json.JSONDecodeError as e:
620
+ # Handle parsing error
621
+ print(f"JSON parsing error: {e}")
622
+ return PROMPTS["fail_response"]
623
+
624
+ context = await _build_global_query_context(
625
+ keywords,
626
+ knowledge_graph_inst,
627
+ entities_vdb,
628
+ relationships_vdb,
629
+ text_chunks_db,
630
+ query_param,
631
+ )
632
+
633
+ if query_param.only_need_context:
634
+ return context
635
+ if context is None:
636
+ return PROMPTS["fail_response"]
637
+
638
+ sys_prompt_temp = PROMPTS["rag_response"]
639
+ sys_prompt = sys_prompt_temp.format(
640
+ context_data=context, response_type=query_param.response_type
641
+ )
642
+ response = await use_model_func(
643
+ query,
644
+ system_prompt=sys_prompt,
645
+ )
646
+ return response
647
+
648
+ async def _build_global_query_context(
649
+ keywords,
650
+ knowledge_graph_inst: BaseGraphStorage,
651
+ entities_vdb: BaseVectorStorage,
652
+ relationships_vdb: BaseVectorStorage,
653
+ text_chunks_db: BaseKVStorage[TextChunkSchema],
654
+ query_param: QueryParam,
655
+ ):
656
+ results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
657
+
658
+ if not len(results):
659
+ return None
660
+
661
+ edge_datas = await asyncio.gather(
662
+ *[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
663
+ )
664
+
665
+ if not all([n is not None for n in edge_datas]):
666
+ logger.warning("Some edges are missing, maybe the storage is damaged")
667
+ edge_degree = await asyncio.gather(
668
+ *[knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"]) for r in results]
669
+ )
670
+ edge_datas = [
671
+ {"src_id": k["src_id"], "tgt_id": k["tgt_id"], "rank": d, **v}
672
+ for k, v, d in zip(results, edge_datas, edge_degree)
673
+ if v is not None
674
+ ]
675
+ edge_datas = sorted(
676
+ edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
677
+ )
678
+ edge_datas = truncate_list_by_token_size(
679
+ edge_datas,
680
+ key=lambda x: x["description"],
681
+ max_token_size=query_param.max_token_for_global_context,
682
+ )
683
+
684
+ use_entities = await _find_most_related_entities_from_relationships(
685
+ edge_datas, query_param, knowledge_graph_inst
686
+ )
687
+ use_text_units = await _find_related_text_unit_from_relationships(
688
+ edge_datas, query_param, text_chunks_db, knowledge_graph_inst
689
+ )
690
+ logger.info(
691
+ f"Global query uses {len(use_entities)} entities, {len(edge_datas)} relations, {len(use_text_units)} text units"
692
+ )
693
+ relations_section_list = [
694
+ ["id", "source", "target", "description", "keywords", "weight", "rank"]
695
+ ]
696
+ for i, e in enumerate(edge_datas):
697
+ relations_section_list.append(
698
+ [
699
+ i,
700
+ e["src_id"],
701
+ e["tgt_id"],
702
+ e["description"],
703
+ e["keywords"],
704
+ e["weight"],
705
+ e["rank"],
706
+ ]
707
+ )
708
+ relations_context = list_of_list_to_csv(relations_section_list)
709
+
710
+ entites_section_list = [["id", "entity", "type", "description", "rank"]]
711
+ for i, n in enumerate(use_entities):
712
+ entites_section_list.append(
713
+ [
714
+ i,
715
+ n["entity_name"],
716
+ n.get("entity_type", "UNKNOWN"),
717
+ n.get("description", "UNKNOWN"),
718
+ n["rank"],
719
+ ]
720
+ )
721
+ entities_context = list_of_list_to_csv(entites_section_list)
722
+
723
+ text_units_section_list = [["id", "content"]]
724
+ for i, t in enumerate(use_text_units):
725
+ text_units_section_list.append([i, t["content"]])
726
+ text_units_context = list_of_list_to_csv(text_units_section_list)
727
+
728
+ return f"""
729
+ -----Entities-----
730
+ ```csv
731
+ {entities_context}
732
+ ```
733
+ -----Relationships-----
734
+ ```csv
735
+ {relations_context}
736
+ ```
737
+ -----Sources-----
738
+ ```csv
739
+ {text_units_context}
740
+ ```
741
+ """
742
+
743
+ async def _find_most_related_entities_from_relationships(
744
+ edge_datas: list[dict],
745
+ query_param: QueryParam,
746
+ knowledge_graph_inst: BaseGraphStorage,
747
+ ):
748
+ entity_names = set()
749
+ for e in edge_datas:
750
+ entity_names.add(e["src_id"])
751
+ entity_names.add(e["tgt_id"])
752
+
753
+ node_datas = await asyncio.gather(
754
+ *[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
755
+ )
756
+
757
+ node_degrees = await asyncio.gather(
758
+ *[knowledge_graph_inst.node_degree(entity_name) for entity_name in entity_names]
759
+ )
760
+ node_datas = [
761
+ {**n, "entity_name": k, "rank": d}
762
+ for k, n, d in zip(entity_names, node_datas, node_degrees)
763
+ ]
764
+
765
+ node_datas = truncate_list_by_token_size(
766
+ node_datas,
767
+ key=lambda x: x["description"],
768
+ max_token_size=query_param.max_token_for_local_context,
769
+ )
770
+
771
+ return node_datas
772
+
773
+ async def _find_related_text_unit_from_relationships(
774
+ edge_datas: list[dict],
775
+ query_param: QueryParam,
776
+ text_chunks_db: BaseKVStorage[TextChunkSchema],
777
+ knowledge_graph_inst: BaseGraphStorage,
778
+ ):
779
+
780
+ text_units = [
781
+ split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
782
+ for dp in edge_datas
783
+ ]
784
+
785
+ all_text_units_lookup = {}
786
+
787
+ for index, unit_list in enumerate(text_units):
788
+ for c_id in unit_list:
789
+ if c_id not in all_text_units_lookup:
790
+ all_text_units_lookup[c_id] = {
791
+ "data": await text_chunks_db.get_by_id(c_id),
792
+ "order": index,
793
+ }
794
+
795
+ if any([v is None for v in all_text_units_lookup.values()]):
796
+ logger.warning("Text chunks are missing, maybe the storage is damaged")
797
+ all_text_units = [
798
+ {"id": k, **v} for k, v in all_text_units_lookup.items() if v is not None
799
+ ]
800
+ all_text_units = sorted(
801
+ all_text_units, key=lambda x: x["order"]
802
+ )
803
+ all_text_units = truncate_list_by_token_size(
804
+ all_text_units,
805
+ key=lambda x: x["data"]["content"],
806
+ max_token_size=query_param.max_token_for_text_unit,
807
+ )
808
+ all_text_units: list[TextChunkSchema] = [t["data"] for t in all_text_units]
809
+
810
+ return all_text_units
811
+
812
+ async def hybird_query(
813
+ query,
814
+ knowledge_graph_inst: BaseGraphStorage,
815
+ entities_vdb: BaseVectorStorage,
816
+ relationships_vdb: BaseVectorStorage,
817
+ text_chunks_db: BaseKVStorage[TextChunkSchema],
818
+ query_param: QueryParam,
819
+ global_config: dict,
820
+ ) -> str:
821
+ use_model_func = global_config["llm_model_func"]
822
+
823
+ kw_prompt_temp = PROMPTS["keywords_extraction"]
824
+ kw_prompt = kw_prompt_temp.format(query=query)
825
+ result = await use_model_func(kw_prompt)
826
+
827
+ try:
828
+ keywords_data = json.loads(result)
829
+ hl_keywords = keywords_data.get("high_level_keywords", [])
830
+ ll_keywords = keywords_data.get("low_level_keywords", [])
831
+ hl_keywords = ', '.join(hl_keywords)
832
+ ll_keywords = ', '.join(ll_keywords)
833
+ except json.JSONDecodeError as e:
834
+ # Handle parsing error
835
+ print(f"JSON parsing error: {e}")
836
+ return PROMPTS["fail_response"]
837
+
838
+ low_level_context = await _build_local_query_context(
839
+ ll_keywords,
840
+ knowledge_graph_inst,
841
+ entities_vdb,
842
+ text_chunks_db,
843
+ query_param,
844
+ )
845
+
846
+ high_level_context = await _build_global_query_context(
847
+ hl_keywords,
848
+ knowledge_graph_inst,
849
+ entities_vdb,
850
+ relationships_vdb,
851
+ text_chunks_db,
852
+ query_param,
853
+ )
854
+
855
+ context = combine_contexts(high_level_context, low_level_context)
856
+
857
+ if query_param.only_need_context:
858
+ return context
859
+ if context is None:
860
+ return PROMPTS["fail_response"]
861
+
862
+ sys_prompt_temp = PROMPTS["rag_response"]
863
+ sys_prompt = sys_prompt_temp.format(
864
+ context_data=context, response_type=query_param.response_type
865
+ )
866
+ response = await use_model_func(
867
+ query,
868
+ system_prompt=sys_prompt,
869
+ )
870
+ return response
871
+
872
+ def combine_contexts(high_level_context, low_level_context):
873
+ # Function to extract entities, relationships, and sources from context strings
874
+ def extract_sections(context):
875
+ entities_match = re.search(r'-----Entities-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
876
+ relationships_match = re.search(r'-----Relationships-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
877
+ sources_match = re.search(r'-----Sources-----\s*```csv\s*(.*?)\s*```', context, re.DOTALL)
878
+
879
+ entities = entities_match.group(1) if entities_match else ''
880
+ relationships = relationships_match.group(1) if relationships_match else ''
881
+ sources = sources_match.group(1) if sources_match else ''
882
+
883
+ return entities, relationships, sources
884
+
885
+ # Extract sections from both contexts
886
+ hl_entities, hl_relationships, hl_sources = extract_sections(high_level_context)
887
+ ll_entities, ll_relationships, ll_sources = extract_sections(low_level_context)
888
+
889
+ # Combine and deduplicate the entities
890
+ combined_entities_set = set(filter(None, hl_entities.strip().split('\n') + ll_entities.strip().split('\n')))
891
+ combined_entities = '\n'.join(combined_entities_set)
892
+
893
+ # Combine and deduplicate the relationships
894
+ combined_relationships_set = set(filter(None, hl_relationships.strip().split('\n') + ll_relationships.strip().split('\n')))
895
+ combined_relationships = '\n'.join(combined_relationships_set)
896
+
897
+ # Combine and deduplicate the sources
898
+ combined_sources_set = set(filter(None, hl_sources.strip().split('\n') + ll_sources.strip().split('\n')))
899
+ combined_sources = '\n'.join(combined_sources_set)
900
+
901
+ # Format the combined context
902
+ return f"""
903
+ -----Entities-----
904
+ ```csv
905
+ {combined_entities}
+ ```
906
+ -----Relationships-----
+ ```csv
907
+ {combined_relationships}
+ ```
908
+ -----Sources-----
+ ```csv
909
+ {combined_sources}
+ ```
910
+ """
911
+
912
+ async def naive_query(
913
+ query,
914
+ chunks_vdb: BaseVectorStorage,
915
+ text_chunks_db: BaseKVStorage[TextChunkSchema],
916
+ query_param: QueryParam,
917
+ global_config: dict,
918
+ ):
919
+ use_model_func = global_config["llm_model_func"]
920
+ results = await chunks_vdb.query(query, top_k=query_param.top_k)
921
+ if not len(results):
922
+ return PROMPTS["fail_response"]
923
+ chunks_ids = [r["id"] for r in results]
924
+ chunks = await text_chunks_db.get_by_ids(chunks_ids)
925
+
926
+ maybe_trun_chunks = truncate_list_by_token_size(
927
+ chunks,
928
+ key=lambda x: x["content"],
929
+ max_token_size=query_param.max_token_for_text_unit,
930
+ )
931
+ logger.info(f"Truncated {len(chunks)} chunks to {len(maybe_trun_chunks)} chunks")
932
+ section = "--New Chunk--\n".join([c["content"] for c in maybe_trun_chunks])
933
+ if query_param.only_need_context:
934
+ return section
935
+ sys_prompt_temp = PROMPTS["naive_rag_response"]
936
+ sys_prompt = sys_prompt_temp.format(
937
+ content_data=section, response_type=query_param.response_type
938
+ )
939
+ response = await use_model_func(
940
+ query,
941
+ system_prompt=sys_prompt,
942
+ )
943
+ return response
944
+
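The keyword-extraction step in `hybird_query` above expects the LLM reply to be a JSON object with `high_level_keywords` and `low_level_keywords` lists. A minimal sketch of that parsing step follows; the reply string is fabricated for illustration and simply stands in for the result of `use_model_func(kw_prompt)`.

```python
# Sketch only: mock_llm_reply is a made-up stand-in for the model's JSON answer.
import json

mock_llm_reply = (
    '{"high_level_keywords": ["International trade", "Global economic stability"],'
    ' "low_level_keywords": ["Tariffs", "Exports", "Currency exchange"]}'
)

keywords_data = json.loads(mock_llm_reply)
hl_keywords = ", ".join(keywords_data.get("high_level_keywords", []))
ll_keywords = ", ".join(keywords_data.get("low_level_keywords", []))

print(hl_keywords)  # International trade, Global economic stability
print(ll_keywords)  # Tariffs, Exports, Currency exchange
```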
lightrag/prompt.py ADDED
@@ -0,0 +1,256 @@
1
+ GRAPH_FIELD_SEP = "<SEP>"
2
+
3
+ PROMPTS = {}
4
+
5
+ PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|>"
6
+ PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##"
7
+ PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>"
8
+ PROMPTS["process_tickers"] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
9
+
10
+ PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event"]
11
+
12
+ PROMPTS[
13
+ "entity_extraction"
14
+ ] = """-Goal-
15
+ Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
16
+
17
+ -Steps-
18
+ 1. Identify all entities. For each identified entity, extract the following information:
19
+ - entity_name: Name of the entity, capitalized
20
+ - entity_type: One of the following types: [{entity_types}]
21
+ - entity_description: Comprehensive description of the entity's attributes and activities
22
+ Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>)
23
+
24
+ 2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
25
+ For each pair of related entities, extract the following information:
26
+ - source_entity: name of the source entity, as identified in step 1
27
+ - target_entity: name of the target entity, as identified in step 1
28
+ - relationship_description: explanation as to why you think the source entity and the target entity are related to each other
29
+ - relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
30
+ - relationship_keywords: one or more high-level key words that summarize the overarching nature of the relationship, focusing on concepts or themes rather than specific details
31
+ Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_keywords>{tuple_delimiter}<relationship_strength>)
32
+
33
+ 3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
34
+ Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_level_keywords>)
35
+
36
+ 4. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
37
+
38
+ 5. When finished, output {completion_delimiter}
39
+
40
+ ######################
41
+ -Examples-
42
+ ######################
43
+ Example 1:
44
+
45
+ Entity_types: [person, technology, mission, organization, location]
46
+ Text:
47
+ while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order.
48
+
49
+ Then Taylor did something unexpected. They paused beside Jordan and, for a moment, observed the device with something akin to reverence. “If this tech can be understood..." Taylor said, their voice quieter, "It could change the game for us. For all of us.”
50
+
51
+ The underlying dismissal earlier seemed to falter, replaced by a glimpse of reluctant respect for the gravity of what lay in their hands. Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's, a wordless clash of wills softening into an uneasy truce.
52
+
53
+ It was a small transformation, barely perceptible, but one that Alex noted with an inward nod. They had all been brought here by different paths
54
+ ################
55
+ Output:
56
+ ("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is a character who experiences frustration and is observant of the dynamics among other characters."){record_delimiter}
57
+ ("entity"{tuple_delimiter}"Taylor"{tuple_delimiter}"person"{tuple_delimiter}"Taylor is portrayed with authoritarian certainty and shows a moment of reverence towards a device, indicating a change in perspective."){record_delimiter}
58
+ ("entity"{tuple_delimiter}"Jordan"{tuple_delimiter}"person"{tuple_delimiter}"Jordan shares a commitment to discovery and has a significant interaction with Taylor regarding a device."){record_delimiter}
59
+ ("entity"{tuple_delimiter}"Cruz"{tuple_delimiter}"person"{tuple_delimiter}"Cruz is associated with a vision of control and order, influencing the dynamics among other characters."){record_delimiter}
60
+ ("entity"{tuple_delimiter}"The Device"{tuple_delimiter}"technology"{tuple_delimiter}"The Device is central to the story, with potential game-changing implications, and is revered by Taylor."){record_delimiter}
61
+ ("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Taylor"{tuple_delimiter}"Alex is affected by Taylor's authoritarian certainty and observes changes in Taylor's attitude towards the device."{tuple_delimiter}"power dynamics, perspective shift"{tuple_delimiter}7){record_delimiter}
62
+ ("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Jordan"{tuple_delimiter}"Alex and Jordan share a commitment to discovery, which contrasts with Cruz's vision."{tuple_delimiter}"shared goals, rebellion"{tuple_delimiter}6){record_delimiter}
63
+ ("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"Jordan"{tuple_delimiter}"Taylor and Jordan interact directly regarding the device, leading to a moment of mutual respect and an uneasy truce."{tuple_delimiter}"conflict resolution, mutual respect"{tuple_delimiter}8){record_delimiter}
64
+ ("relationship"{tuple_delimiter}"Jordan"{tuple_delimiter}"Cruz"{tuple_delimiter}"Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order."{tuple_delimiter}"ideological conflict, rebellion"{tuple_delimiter}5){record_delimiter}
65
+ ("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}"reverence, technological significance"{tuple_delimiter}9){record_delimiter}
66
+ ("content_keywords"{tuple_delimiter}"power dynamics, ideological conflict, discovery, rebellion"){completion_delimiter}
67
+ #############################
68
+ Example 2:
69
+
70
+ Entity_types: [person, technology, mission, organization, location]
71
+ Text:
72
+ They were no longer mere operatives; they had become guardians of a threshold, keepers of a message from a realm beyond stars and stripes. This elevation in their mission could not be shackled by regulations and established protocols—it demanded a new perspective, a new resolve.
73
+
74
+ Tension threaded through the dialogue of beeps and static as communications with Washington buzzed in the background. The team stood, a portentous air enveloping them. It was clear that the decisions they made in the ensuing hours could redefine humanity's place in the cosmos or condemn them to ignorance and potential peril.
75
+
76
+ Their connection to the stars solidified, the group moved to address the crystallizing warning, shifting from passive recipients to active participants. Mercer's latter instincts gained precedence— the team's mandate had evolved, no longer solely to observe and report but to interact and prepare. A metamorphosis had begun, and Operation: Dulce hummed with the newfound frequency of their daring, a tone set not by the earthly
77
+ #############
78
+ Output:
79
+ ("entity"{tuple_delimiter}"Washington"{tuple_delimiter}"location"{tuple_delimiter}"Washington is a location where communications are being received, indicating its importance in the decision-making process."){record_delimiter}
80
+ ("entity"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"mission"{tuple_delimiter}"Operation: Dulce is described as a mission that has evolved to interact and prepare, indicating a significant shift in objectives and activities."){record_delimiter}
81
+ ("entity"{tuple_delimiter}"The team"{tuple_delimiter}"organization"{tuple_delimiter}"The team is portrayed as a group of individuals who have transitioned from passive observers to active participants in a mission, showing a dynamic change in their role."){record_delimiter}
82
+ ("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Washington"{tuple_delimiter}"The team receives communications from Washington, which influences their decision-making process."{tuple_delimiter}"decision-making, external influence"{tuple_delimiter}7){record_delimiter}
83
+ ("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"The team is directly involved in Operation: Dulce, executing its evolved objectives and activities."{tuple_delimiter}"mission evolution, active participation"{tuple_delimiter}9){record_delimiter}
84
+ ("content_keywords"{tuple_delimiter}"mission evolution, decision-making, active participation, cosmic significance"){completion_delimiter}
85
+ #############################
86
+ Example 3:
87
+
88
+ Entity_types: [person, role, technology, organization, event, location, concept]
89
+ Text:
90
+ their voice slicing through the buzz of activity. "Control may be an illusion when facing an intelligence that literally writes its own rules," they stated stoically, casting a watchful eye over the flurry of data.
91
+
92
+ "It's like it's learning to communicate," offered Sam Rivera from a nearby interface, their youthful energy boding a mix of awe and anxiety. "This gives 'talking to strangers' a whole new meaning."
93
+
94
+ Alex surveyed his team—each face a study in concentration, determination, and not a small measure of trepidation. "This might well be our first contact," he acknowledged, "And we need to be ready for whatever answers back."
95
+
96
+ Together, they stood on the edge of the unknown, forging humanity's response to a message from the heavens. The ensuing silence was palpable—a collective introspection about their role in this grand cosmic play, one that could rewrite human history.
97
+
98
+ The encrypted dialogue continued to unfold, its intricate patterns showing an almost uncanny anticipation
99
+ #############
100
+ Output:
101
+ ("entity"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"person"{tuple_delimiter}"Sam Rivera is a member of a team working on communicating with an unknown intelligence, showing a mix of awe and anxiety."){record_delimiter}
102
+ ("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is the leader of a team attempting first contact with an unknown intelligence, acknowledging the significance of their task."){record_delimiter}
103
+ ("entity"{tuple_delimiter}"Control"{tuple_delimiter}"concept"{tuple_delimiter}"Control refers to the ability to manage or govern, which is challenged by an intelligence that writes its own rules."){record_delimiter}
104
+ ("entity"{tuple_delimiter}"Intelligence"{tuple_delimiter}"concept"{tuple_delimiter}"Intelligence here refers to an unknown entity capable of writing its own rules and learning to communicate."){record_delimiter}
105
+ ("entity"{tuple_delimiter}"First Contact"{tuple_delimiter}"event"{tuple_delimiter}"First Contact is the potential initial communication between humanity and an unknown intelligence."){record_delimiter}
106
+ ("entity"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"event"{tuple_delimiter}"Humanity's Response is the collective action taken by Alex's team in response to a message from an unknown intelligence."){record_delimiter}
107
+ ("relationship"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"Intelligence"{tuple_delimiter}"Sam Rivera is directly involved in the process of learning to communicate with the unknown intelligence."{tuple_delimiter}"communication, learning process"{tuple_delimiter}9){record_delimiter}
108
+ ("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"First Contact"{tuple_delimiter}"Alex leads the team that might be making the First Contact with the unknown intelligence."{tuple_delimiter}"leadership, exploration"{tuple_delimiter}10){record_delimiter}
109
+ ("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"Alex and his team are the key figures in Humanity's Response to the unknown intelligence."{tuple_delimiter}"collective action, cosmic significance"{tuple_delimiter}8){record_delimiter}
110
+ ("relationship"{tuple_delimiter}"Control"{tuple_delimiter}"Intelligence"{tuple_delimiter}"The concept of Control is challenged by the Intelligence that writes its own rules."{tuple_delimiter}"power dynamics, autonomy"{tuple_delimiter}7){record_delimiter}
111
+ ("content_keywords"{tuple_delimiter}"first contact, control, communication, cosmic significance"){completion_delimiter}
112
+ #############################
113
+ -Real Data-
114
+ ######################
115
+ Entity_types: {entity_types}
116
+ Text: {input_text}
117
+ ######################
118
+ Output:
119
+ """
120
+
121
+ PROMPTS[
122
+ "summarize_entity_descriptions"
123
+ ] = """You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
124
+ Given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
125
+ Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
126
+ If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
127
+ Make sure it is written in third person, and include the entity names so we have the full context.
128
+
129
+ #######
130
+ -Data-
131
+ Entities: {entity_name}
132
+ Description List: {description_list}
133
+ #######
134
+ Output:
135
+ """
136
+
137
+ PROMPTS[
138
+ "entiti_continue_extraction"
139
+ ] = """MANY entities were missed in the last extraction. Add them below using the same format:
140
+ """
141
+
142
+ PROMPTS[
143
+ "entiti_if_loop_extraction"
144
+ ] = """It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added.
145
+ """
146
+
147
+ PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."
148
+
149
+ PROMPTS[
150
+ "rag_response"
151
+ ] = """---Role---
152
+
153
+ You are a helpful assistant responding to questions about data in the tables provided.
154
+
155
+
156
+ ---Goal---
157
+
158
+ Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
159
+ If you don't know the answer, just say so. Do not make anything up.
160
+ Do not include information where the supporting evidence for it is not provided.
161
+
162
+ ---Target response length and format---
163
+
164
+ {response_type}
165
+
166
+
167
+ ---Data tables---
168
+
169
+ {context_data}
170
+
171
+
172
+ ---Goal---
173
+
174
+ Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
175
+
176
+ If you don't know the answer, just say so. Do not make anything up.
177
+
178
+ Do not include information where the supporting evidence for it is not provided.
179
+
180
+
181
+ ---Target response length and format---
182
+
183
+ {response_type}
184
+
185
+ Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
186
+ """
187
+
188
+ PROMPTS["keywords_extraction"] = """---Role---
189
+
190
+ You are a helpful assistant tasked with identifying both high-level and low-level keywords in the user's query.
191
+
192
+ ---Goal---
193
+
194
+ Given the query, list both high-level and low-level keywords. High-level keywords focus on overarching concepts or themes, while low-level keywords focus on specific entities, details, or concrete terms.
195
+
196
+ ---Instructions---
197
+
198
+ - Output the keywords in JSON format.
199
+ - The JSON should have two keys:
200
+ - "high_level_keywords" for overarching concepts or themes.
201
+ - "low_level_keywords" for specific entities or details.
202
+
203
+ ######################
204
+ -Examples-
205
+ ######################
206
+ Example 1:
207
+
208
+ Query: "How does international trade influence global economic stability?"
209
+ ################
210
+ Output:
211
+ {{
212
+ "high_level_keywords": ["International trade", "Global economic stability", "Economic impact"],
213
+ "low_level_keywords": ["Trade agreements", "Tariffs", "Currency exchange", "Imports", "Exports"]
214
+ }}
215
+ #############################
216
+ Example 2:
217
+
218
+ Query: "What are the environmental consequences of deforestation on biodiversity?"
219
+ ################
220
+ Output:
221
+ {{
222
+ "high_level_keywords": ["Environmental consequences", "Deforestation", "Biodiversity loss"],
223
+ "low_level_keywords": ["Species extinction", "Habitat destruction", "Carbon emissions", "Rainforest", "Ecosystem"]
224
+ }}
225
+ #############################
226
+ Example 3:
227
+
228
+ Query: "What is the role of education in reducing poverty?"
229
+ ################
230
+ Output:
231
+ {{
232
+ "high_level_keywords": ["Education", "Poverty reduction", "Socioeconomic development"],
233
+ "low_level_keywords": ["School access", "Literacy rates", "Job training", "Income inequality"]
234
+ }}
235
+ #############################
236
+ -Real Data-
237
+ ######################
238
+ Query: {query}
239
+ ######################
240
+ Output:
241
+
242
+ """
243
+
244
+ PROMPTS[
245
+ "naive_rag_response"
246
+ ] = """You are a helpful assistant.
247
+ Below is the knowledge you know:
248
+ {content_data}
249
+ ---
250
+ If you don't know the answer or if the provided knowledge does not contain sufficient information to provide an answer, just say so. Do not make anything up.
251
+ Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
252
+ If you don't know the answer, just say so. Do not make anything up.
253
+ Do not include information where the supporting evidence for it is not provided.
254
+ ---Target response length and format---
255
+ {response_type}
256
+ """
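These templates are plain `str.format` strings. A minimal sketch of filling two of them with the default delimiters defined at the top of this file; it assumes the package is importable as `lightrag`, and the sample text and query are made up.

```python
# Sketch only: fill the prompt templates with the default delimiters.
from lightrag.prompt import PROMPTS

entity_prompt = PROMPTS["entity_extraction"].format(
    entity_types=", ".join(PROMPTS["DEFAULT_ENTITY_TYPES"]),
    tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
    record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
    completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
    input_text="Alex briefed the team on Operation: Dulce.",  # made-up sample text
)

kw_prompt = PROMPTS["keywords_extraction"].format(
    query="How does deforestation affect biodiversity?"  # made-up sample query
)

print(entity_prompt[:200])
print(kw_prompt[:200])
```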
lightrag/storage.py ADDED
@@ -0,0 +1,246 @@
1
+ import asyncio
2
+ import html
3
+ import json
4
+ import os
5
+ from collections import defaultdict
6
+ from dataclasses import dataclass, field
7
+ from typing import Any, Union, cast
8
+ import pickle
9
+ import hnswlib
10
+ import networkx as nx
11
+ import numpy as np
12
+ from nano_vectordb import NanoVectorDB
13
+ import xxhash
14
+
15
+ from .utils import load_json, logger, write_json
16
+ from .base import (
17
+ BaseGraphStorage,
18
+ BaseKVStorage,
19
+ BaseVectorStorage,
20
+ )
21
+
22
+ @dataclass
23
+ class JsonKVStorage(BaseKVStorage):
24
+ def __post_init__(self):
25
+ working_dir = self.global_config["working_dir"]
26
+ self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
27
+ self._data = load_json(self._file_name) or {}
28
+ logger.info(f"Loaded KV store {self.namespace} with {len(self._data)} entries")
29
+
30
+ async def all_keys(self) -> list[str]:
31
+ return list(self._data.keys())
32
+
33
+ async def index_done_callback(self):
34
+ write_json(self._data, self._file_name)
35
+
36
+ async def get_by_id(self, id):
37
+ return self._data.get(id, None)
38
+
39
+ async def get_by_ids(self, ids, fields=None):
40
+ if fields is None:
41
+ return [self._data.get(id, None) for id in ids]
42
+ return [
43
+ (
44
+ {k: v for k, v in self._data[id].items() if k in fields}
45
+ if self._data.get(id, None)
46
+ else None
47
+ )
48
+ for id in ids
49
+ ]
50
+
51
+ async def filter_keys(self, data: list[str]) -> set[str]:
52
+ return set([s for s in data if s not in self._data])
53
+
54
+ async def upsert(self, data: dict[str, dict]):
55
+ left_data = {k: v for k, v in data.items() if k not in self._data}
56
+ self._data.update(left_data)
57
+ return left_data
58
+
59
+ async def drop(self):
60
+ self._data = {}
61
+
62
+ @dataclass
63
+ class NanoVectorDBStorage(BaseVectorStorage):
64
+ cosine_better_than_threshold: float = 0.2
65
+
66
+ def __post_init__(self):
67
+
68
+ self._client_file_name = os.path.join(
69
+ self.global_config["working_dir"], f"vdb_{self.namespace}.json"
70
+ )
71
+ self._max_batch_size = self.global_config["embedding_batch_num"]
72
+ self._client = NanoVectorDB(
73
+ self.embedding_func.embedding_dim, storage_file=self._client_file_name
74
+ )
75
+ self.cosine_better_than_threshold = self.global_config.get(
76
+ "cosine_better_than_threshold", self.cosine_better_than_threshold
77
+ )
78
+
79
+ async def upsert(self, data: dict[str, dict]):
80
+ logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
81
+ if not len(data):
82
+ logger.warning("You are inserting empty data into the vector DB")
83
+ return []
84
+ list_data = [
85
+ {
86
+ "__id__": k,
87
+ **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
88
+ }
89
+ for k, v in data.items()
90
+ ]
91
+ contents = [v["content"] for v in data.values()]
92
+ batches = [
93
+ contents[i : i + self._max_batch_size]
94
+ for i in range(0, len(contents), self._max_batch_size)
95
+ ]
96
+ embeddings_list = await asyncio.gather(
97
+ *[self.embedding_func(batch) for batch in batches]
98
+ )
99
+ embeddings = np.concatenate(embeddings_list)
100
+ for i, d in enumerate(list_data):
101
+ d["__vector__"] = embeddings[i]
102
+ results = self._client.upsert(datas=list_data)
103
+ return results
104
+
105
+ async def query(self, query: str, top_k=5):
106
+ embedding = await self.embedding_func([query])
107
+ embedding = embedding[0]
108
+ results = self._client.query(
109
+ query=embedding,
110
+ top_k=top_k,
111
+ better_than_threshold=self.cosine_better_than_threshold,
112
+ )
113
+ results = [
114
+ {**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
115
+ ]
116
+ return results
117
+
118
+ async def index_done_callback(self):
119
+ self._client.save()
120
+
121
+ @dataclass
122
+ class NetworkXStorage(BaseGraphStorage):
123
+ @staticmethod
124
+ def load_nx_graph(file_name) -> nx.Graph:
125
+ if os.path.exists(file_name):
126
+ return nx.read_graphml(file_name)
127
+ return None
128
+
129
+ @staticmethod
130
+ def write_nx_graph(graph: nx.Graph, file_name):
131
+ logger.info(
132
+ f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
133
+ )
134
+ nx.write_graphml(graph, file_name)
135
+
136
+ @staticmethod
137
+ def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
138
+ """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
139
+ Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
140
+ """
141
+ from graspologic.utils import largest_connected_component
142
+
143
+ graph = graph.copy()
144
+ graph = cast(nx.Graph, largest_connected_component(graph))
145
+ node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()} # type: ignore
146
+ graph = nx.relabel_nodes(graph, node_mapping)
147
+ return NetworkXStorage._stabilize_graph(graph)
148
+
149
+ @staticmethod
150
+ def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
151
+ """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
152
+ Ensure an undirected graph with the same relationships will always be read the same way.
153
+ """
154
+ fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
155
+
156
+ sorted_nodes = graph.nodes(data=True)
157
+ sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])
158
+
159
+ fixed_graph.add_nodes_from(sorted_nodes)
160
+ edges = list(graph.edges(data=True))
161
+
162
+ if not graph.is_directed():
163
+
164
+ def _sort_source_target(edge):
165
+ source, target, edge_data = edge
166
+ if source > target:
167
+ temp = source
168
+ source = target
169
+ target = temp
170
+ return source, target, edge_data
171
+
172
+ edges = [_sort_source_target(edge) for edge in edges]
173
+
174
+ def _get_edge_key(source: Any, target: Any) -> str:
175
+ return f"{source} -> {target}"
176
+
177
+ edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
178
+
179
+ fixed_graph.add_edges_from(edges)
180
+ return fixed_graph
181
+
182
+ def __post_init__(self):
183
+ self._graphml_xml_file = os.path.join(
184
+ self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
185
+ )
186
+ preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
187
+ if preloaded_graph is not None:
188
+ logger.info(
189
+ f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
190
+ )
191
+ self._graph = preloaded_graph or nx.Graph()
192
+ self._node_embed_algorithms = {
193
+ "node2vec": self._node2vec_embed,
194
+ }
195
+
196
+ async def index_done_callback(self):
197
+ NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
198
+
199
+ async def has_node(self, node_id: str) -> bool:
200
+ return self._graph.has_node(node_id)
201
+
202
+ async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
203
+ return self._graph.has_edge(source_node_id, target_node_id)
204
+
205
+ async def get_node(self, node_id: str) -> Union[dict, None]:
206
+ return self._graph.nodes.get(node_id)
207
+
208
+ async def node_degree(self, node_id: str) -> int:
209
+ return self._graph.degree(node_id)
210
+
211
+ async def edge_degree(self, src_id: str, tgt_id: str) -> int:
212
+ return self._graph.degree(src_id) + self._graph.degree(tgt_id)
213
+
214
+ async def get_edge(
215
+ self, source_node_id: str, target_node_id: str
216
+ ) -> Union[dict, None]:
217
+ return self._graph.edges.get((source_node_id, target_node_id))
218
+
219
+ async def get_node_edges(self, source_node_id: str):
220
+ if self._graph.has_node(source_node_id):
221
+ return list(self._graph.edges(source_node_id))
222
+ return None
223
+
224
+ async def upsert_node(self, node_id: str, node_data: dict[str, str]):
225
+ self._graph.add_node(node_id, **node_data)
226
+
227
+ async def upsert_edge(
228
+ self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
229
+ ):
230
+ self._graph.add_edge(source_node_id, target_node_id, **edge_data)
231
+
232
+ async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
233
+ if algorithm not in self._node_embed_algorithms:
234
+ raise ValueError(f"Node embedding algorithm {algorithm} not supported")
235
+ return await self._node_embed_algorithms[algorithm]()
236
+
237
+ async def _node2vec_embed(self):
238
+ from graspologic import embed
239
+
240
+ embeddings, nodes = embed.node2vec_embed(
241
+ self._graph,
242
+ **self.global_config["node2vec_params"],
243
+ )
244
+
245
+ nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
246
+ return embeddings, nodes_ids
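`NetworkXStorage._stabilize_graph` above makes GraphML output deterministic by sorting nodes and by normalizing and sorting undirected edges before writing. A condensed, self-contained sketch of that ordering idea (the toy graph and file name are made up):

```python
# Sketch only: deterministic serialization via sorted nodes and normalized, sorted edges.
import networkx as nx

g = nx.Graph()
g.add_edge("TAYLOR", "ALEX", description="works with")
g.add_edge("ALEX", "JORDAN", description="shares goals with")

stable = nx.Graph()
stable.add_nodes_from(sorted(g.nodes(data=True), key=lambda n: n[0]))
stable.add_edges_from(
    sorted(
        (tuple(sorted((u, v))) + (d,) for u, v, d in g.edges(data=True)),
        key=lambda e: f"{e[0]} -> {e[1]}",
    )
)
nx.write_graphml(stable, "graph_demo.graphml")
```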
lightrag/utils.py ADDED
@@ -0,0 +1,165 @@
1
+ import asyncio
2
+ import html
3
+ import json
4
+ import logging
5
+ import os
6
+ import re
7
+ from dataclasses import dataclass
8
+ from functools import wraps
9
+ from hashlib import md5
10
+ from typing import Any, Union
11
+
12
+ import numpy as np
13
+ import tiktoken
14
+
15
+ ENCODER = None
16
+
17
+ logger = logging.getLogger("lightrag")
18
+
19
+ def set_logger(log_file: str):
20
+ logger.setLevel(logging.DEBUG)
21
+
22
+ file_handler = logging.FileHandler(log_file)
23
+ file_handler.setLevel(logging.DEBUG)
24
+
25
+ formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
26
+ file_handler.setFormatter(formatter)
27
+
28
+ if not logger.handlers:
29
+ logger.addHandler(file_handler)
30
+
31
+ @dataclass
32
+ class EmbeddingFunc:
33
+ embedding_dim: int
34
+ max_token_size: int
35
+ func: callable
36
+
37
+ async def __call__(self, *args, **kwargs) -> np.ndarray:
38
+ return await self.func(*args, **kwargs)
39
+
40
+ def locate_json_string_body_from_string(content: str) -> Union[str, None]:
41
+ """Locate the JSON string body from a string"""
42
+ maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
43
+ if maybe_json_str is not None:
44
+ return maybe_json_str.group(0)
45
+ else:
46
+ return None
47
+
48
+ def convert_response_to_json(response: str) -> dict:
49
+ json_str = locate_json_string_body_from_string(response)
50
+ assert json_str is not None, f"Unable to parse JSON from response: {response}"
51
+ try:
52
+ data = json.loads(json_str)
53
+ return data
54
+ except json.JSONDecodeError as e:
55
+ logger.error(f"Failed to parse JSON: {json_str}")
56
+ raise e from None
57
+
58
+ def compute_args_hash(*args):
59
+ return md5(str(args).encode()).hexdigest()
60
+
61
+ def compute_mdhash_id(content, prefix: str = ""):
62
+ return prefix + md5(content.encode()).hexdigest()
63
+
64
+ def limit_async_func_call(max_size: int, waiting_time: float = 0.0001):
65
+ """Limit the maximum number of concurrent calls for an async func"""
66
+
67
+ def final_decro(func):
68
+ """Not using asyncio.Semaphore to avoid requiring nest-asyncio"""
69
+ __current_size = 0
70
+
71
+ @wraps(func)
72
+ async def wait_func(*args, **kwargs):
73
+ nonlocal __current_size
74
+ while __current_size >= max_size:
75
+ await asyncio.sleep(waiting_time)
76
+ __current_size += 1
77
+ result = await func(*args, **kwargs)
78
+ __current_size -= 1
79
+ return result
80
+
81
+ return wait_func
82
+
83
+ return final_decro
84
+
85
+ def wrap_embedding_func_with_attrs(**kwargs):
86
+ """Wrap a function with attributes"""
87
+
88
+ def final_decro(func) -> EmbeddingFunc:
89
+ new_func = EmbeddingFunc(**kwargs, func=func)
90
+ return new_func
91
+
92
+ return final_decro
93
+
94
+ def load_json(file_name):
95
+ if not os.path.exists(file_name):
96
+ return None
97
+ with open(file_name) as f:
98
+ return json.load(f)
99
+
100
+ def write_json(json_obj, file_name):
101
+ with open(file_name, "w") as f:
102
+ json.dump(json_obj, f, indent=2, ensure_ascii=False)
103
+
104
+ def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o"):
105
+ global ENCODER
106
+ if ENCODER is None:
107
+ ENCODER = tiktoken.encoding_for_model(model_name)
108
+ tokens = ENCODER.encode(content)
109
+ return tokens
110
+
111
+
112
+ def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o"):
113
+ global ENCODER
114
+ if ENCODER is None:
115
+ ENCODER = tiktoken.encoding_for_model(model_name)
116
+ content = ENCODER.decode(tokens)
117
+ return content
118
+
119
+ def pack_user_ass_to_openai_messages(*args: str):
120
+ roles = ["user", "assistant"]
121
+ return [
122
+ {"role": roles[i % 2], "content": content} for i, content in enumerate(args)
123
+ ]
124
+
125
+ def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
126
+ """Split a string by multiple markers"""
127
+ if not markers:
128
+ return [content]
129
+ results = re.split("|".join(re.escape(marker) for marker in markers), content)
130
+ return [r.strip() for r in results if r.strip()]
131
+
132
+ # Refer the utils functions of the official GraphRAG implementation:
133
+ # https://github.com/microsoft/graphrag
134
+ def clean_str(input: Any) -> str:
135
+ """Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
136
+ # If we get non-string input, just give it back
137
+ if not isinstance(input, str):
138
+ return input
139
+
140
+ result = html.unescape(input.strip())
141
+ # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
142
+ return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
143
+
144
+ def is_float_regex(value):
145
+ return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
146
+
147
+ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
148
+ """Truncate a list of data by token size"""
149
+ if max_token_size <= 0:
150
+ return []
151
+ tokens = 0
152
+ for i, data in enumerate(list_data):
153
+ tokens += len(encode_string_by_tiktoken(key(data)))
154
+ if tokens > max_token_size:
155
+ return list_data[:i]
156
+ return list_data
157
+
158
+ def list_of_list_to_csv(data: list[list]):
159
+ return "\n".join(
160
+ [",\t".join([str(data_dd) for data_dd in data_d]) for data_d in data]
161
+ )
162
+
163
+ def save_data_to_file(data, file_name):
164
+ with open(file_name, 'w', encoding='utf-8') as f:
165
+ json.dump(data, f, ensure_ascii=False, indent=4)
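A brief usage sketch for a few of the helpers above; it requires `tiktoken`, and the sample chunks, separator string, and id prefix are made up for illustration.

```python
# Sketch only: token-budget truncation, marker splitting, and content hashing.
from lightrag.utils import (
    compute_mdhash_id,
    split_string_by_multi_markers,
    truncate_list_by_token_size,
)

chunks = [{"content": "First chunk of text."}, {"content": "Second chunk of text."}]

# Keep as many leading chunks as fit within a 10-token budget.
kept = truncate_list_by_token_size(chunks, key=lambda c: c["content"], max_token_size=10)

# Split a joined source_id field on the <SEP> separator defined in lightrag/prompt.py.
ids = split_string_by_multi_markers("chunk-a<SEP>chunk-b", ["<SEP>"])

# Deterministic content hash with an arbitrary prefix.
print(compute_mdhash_id("First chunk of text.", prefix="chunk-"))
print(len(kept), ids)
```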