Daniel.y commited on
Commit
4644ee6
·
unverified ·
2 Parent(s): f1a2f89 141cb64

Merge pull request #1753 from HKUDS/rerank

Browse files
README-zh.md CHANGED
@@ -294,6 +294,16 @@ class QueryParam:
294
  top_k: int = int(os.getenv("TOP_K", "60"))
295
  """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
296
 
 
 
 
 
 
 
 
 
 
 
297
  max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))
298
  """Maximum number of tokens allowed for each retrieved text chunk."""
299
 
@@ -849,6 +859,18 @@ rag = LightRAG(
849
 
850
  </details>
851
 
 
 
 
 
 
 
 
 
 
 
 
 
852
  ## 编辑实体和关系
853
 
854
  LightRAG现在支持全面的知识图谱管理功能,允许您在知识图谱中创建、编辑和删除实体和关系。
 
294
  top_k: int = int(os.getenv("TOP_K", "60"))
295
  """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
296
 
297
+ chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5"))
298
+ """Number of text chunks to retrieve initially from vector search.
299
+ If None, defaults to top_k value.
300
+ """
301
+
302
+ chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5"))
303
+ """Number of text chunks to keep after reranking.
304
+ If None, keeps all chunks returned from initial retrieval.
305
+ """
306
+
307
  max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))
308
  """Maximum number of tokens allowed for each retrieved text chunk."""
309
 
 
859
 
860
  </details>
861
 
862
+ ### LightRAG实例间的数据隔离
863
+
864
+ 通过 workspace 参数可以不同实现不同LightRAG实例之间的存储数据隔离。LightRAG在初始化后workspace就已经确定,之后修改workspace是无效的。下面是不同类型的存储实现工作空间的方式:
865
+
866
+ - **对于本地基于文件的数据库,数据隔离通过工作空间子目录实现:** JsonKVStorage, JsonDocStatusStorage, NetworkXStorage, NanoVectorDBStorage, FaissVectorDBStorage。
867
+ - **对于将数据存储在集合(collection)中的数据库,通过在集合名称前添加工作空间前缀来实现:** RedisKVStorage, RedisDocStatusStorage, MilvusVectorDBStorage, QdrantVectorDBStorage, MongoKVStorage, MongoDocStatusStorage, MongoVectorDBStorage, MongoGraphStorage, PGGraphStorage。
868
+ - **对于关系型数据库,数据隔离通过向表中添加 `workspace` 字段进行数据的逻辑隔离:** PGKVStorage, PGVectorStorage, PGDocStatusStorage。
869
+
870
+ * **对于Neo4j图数据库,通过label来实现数据的逻辑隔离**:Neo4JStorage
871
+
872
+ 为了保持对遗留数据的兼容,在未配置工作空间时PostgreSQL的默认工作空间为`default`,Neo4j的默认工作空间为`base`。对于所有的外部存储,系统都提供了专用的工作空间环境变量,用于覆盖公共的 `WORKSPACE`环境变量配置。这些适用于指定存储类型的工作空间环境变量为:`REDIS_WORKSPACE`, `MILVUS_WORKSPACE`, `QDRANT_WORKSPACE`, `MONGODB_WORKSPACE`, `POSTGRES_WORKSPACE`, `NEO4J_WORKSPACE`。
873
+
874
  ## 编辑实体和关系
875
 
876
  LightRAG现在支持全面的知识图谱管理功能,允许您在知识图谱中创建、编辑和删除实体和关系。
README.md CHANGED
@@ -153,7 +153,7 @@ curl https://raw.githubusercontent.com/gusye1234/nano-graphrag/main/tests/mock_d
153
  python examples/lightrag_openai_demo.py
154
  ```
155
 
156
- For a streaming response implementation example, please see `examples/lightrag_openai_compatible_demo.py`. Prior to execution, ensure you modify the sample codes LLM and embedding configurations accordingly.
157
 
158
  **Note 1**: When running the demo program, please be aware that different test scripts may use different embedding models. If you switch to a different embedding model, you must clear the data directory (`./dickens`); otherwise, the program may encounter errors. If you wish to retain the LLM cache, you can preserve the `kv_store_llm_response_cache.json` file while clearing the data directory.
159
 
@@ -239,6 +239,7 @@ A full list of LightRAG init parameters:
239
  | **Parameter** | **Type** | **Explanation** | **Default** |
240
  |--------------|----------|-----------------|-------------|
241
  | **working_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` |
 
242
  | **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`,`PGKVStorage`,`RedisKVStorage`,`MongoKVStorage` | `JsonKVStorage` |
243
  | **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`,`PGVectorStorage`,`MilvusVectorDBStorage`,`ChromaVectorDBStorage`,`FaissVectorDBStorage`,`MongoVectorDBStorage`,`QdrantVectorDBStorage` | `NanoVectorDBStorage` |
244
  | **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`,`Neo4JStorage`,`PGGraphStorage`,`AGEStorage` | `NetworkXStorage` |
@@ -300,6 +301,16 @@ class QueryParam:
300
  top_k: int = int(os.getenv("TOP_K", "60"))
301
  """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
302
 
 
 
 
 
 
 
 
 
 
 
303
  max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))
304
  """Maximum number of tokens allowed for each retrieved text chunk."""
305
 
@@ -895,6 +906,17 @@ async def initialize_rag():
895
 
896
  </details>
897
 
 
 
 
 
 
 
 
 
 
 
 
898
  ## Edit Entities and Relations
899
 
900
  LightRAG now supports comprehensive knowledge graph management capabilities, allowing you to create, edit, and delete entities and relationships within your knowledge graph.
 
153
  python examples/lightrag_openai_demo.py
154
  ```
155
 
156
+ For a streaming response implementation example, please see `examples/lightrag_openai_compatible_demo.py`. Prior to execution, ensure you modify the sample code's LLM and embedding configurations accordingly.
157
 
158
  **Note 1**: When running the demo program, please be aware that different test scripts may use different embedding models. If you switch to a different embedding model, you must clear the data directory (`./dickens`); otherwise, the program may encounter errors. If you wish to retain the LLM cache, you can preserve the `kv_store_llm_response_cache.json` file while clearing the data directory.
159
 
 
239
  | **Parameter** | **Type** | **Explanation** | **Default** |
240
  |--------------|----------|-----------------|-------------|
241
  | **working_dir** | `str` | Directory where the cache will be stored | `lightrag_cache+timestamp` |
242
+ | **workspace** | str | Workspace name for data isolation between different LightRAG Instances | |
243
  | **kv_storage** | `str` | Storage type for documents and text chunks. Supported types: `JsonKVStorage`,`PGKVStorage`,`RedisKVStorage`,`MongoKVStorage` | `JsonKVStorage` |
244
  | **vector_storage** | `str` | Storage type for embedding vectors. Supported types: `NanoVectorDBStorage`,`PGVectorStorage`,`MilvusVectorDBStorage`,`ChromaVectorDBStorage`,`FaissVectorDBStorage`,`MongoVectorDBStorage`,`QdrantVectorDBStorage` | `NanoVectorDBStorage` |
245
  | **graph_storage** | `str` | Storage type for graph edges and nodes. Supported types: `NetworkXStorage`,`Neo4JStorage`,`PGGraphStorage`,`AGEStorage` | `NetworkXStorage` |
 
301
  top_k: int = int(os.getenv("TOP_K", "60"))
302
  """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
303
 
304
+ chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5"))
305
+ """Number of text chunks to retrieve initially from vector search.
306
+ If None, defaults to top_k value.
307
+ """
308
+
309
+ chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5"))
310
+ """Number of text chunks to keep after reranking.
311
+ If None, keeps all chunks returned from initial retrieval.
312
+ """
313
+
314
  max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))
315
  """Maximum number of tokens allowed for each retrieved text chunk."""
316
 
 
906
 
907
  </details>
908
 
909
+ ### Data Isolation Between LightRAG Instances
910
+
911
+ The `workspace` parameter ensures data isolation between different LightRAG instances. Once initialized, the `workspace` is immutable and cannot be changed.Here is how workspaces are implemented for different types of storage:
912
+
913
+ - **For local file-based databases, data isolation is achieved through workspace subdirectories:** `JsonKVStorage`, `JsonDocStatusStorage`, `NetworkXStorage`, `NanoVectorDBStorage`, `FaissVectorDBStorage`.
914
+ - **For databases that store data in collections, it's done by adding a workspace prefix to the collection name:** `RedisKVStorage`, `RedisDocStatusStorage`, `MilvusVectorDBStorage`, `QdrantVectorDBStorage`, `MongoKVStorage`, `MongoDocStatusStorage`, `MongoVectorDBStorage`, `MongoGraphStorage`, `PGGraphStorage`.
915
+ - **For relational databases, data isolation is achieved by adding a `workspace` field to the tables for logical data separation:** `PGKVStorage`, `PGVectorStorage`, `PGDocStatusStorage`.
916
+ - **For the Neo4j graph database, logical data isolation is achieved through labels:** `Neo4JStorage`
917
+
918
+ To maintain compatibility with legacy data, the default workspace for PostgreSQL is `default` and for Neo4j is `base` when no workspace is configured. For all external storages, the system provides dedicated workspace environment variables to override the common `WORKSPACE` environment variable configuration. These storage-specific workspace environment variables are: `REDIS_WORKSPACE`, `MILVUS_WORKSPACE`, `QDRANT_WORKSPACE`, `MONGODB_WORKSPACE`, `POSTGRES_WORKSPACE`, `NEO4J_WORKSPACE`.
919
+
920
  ## Edit Entities and Relations
921
 
922
  LightRAG now supports comprehensive knowledge graph management capabilities, allowing you to create, edit, and delete entities and relationships within your knowledge graph.
docs/rerank_integration.md ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Rerank Integration in LightRAG
2
+
3
+ This document explains how to configure and use the rerank functionality in LightRAG to improve retrieval quality.
4
+
5
+ ## Overview
6
+
7
+ Reranking is an optional feature that improves the quality of retrieved documents by re-ordering them based on their relevance to the query. This is particularly useful when you want higher precision in document retrieval across all query modes (naive, local, global, hybrid, mix).
8
+
9
+ ## Architecture
10
+
11
+ The rerank integration follows a simplified design pattern:
12
+
13
+ - **Single Function Configuration**: All rerank settings (model, API keys, top_k, etc.) are contained within the rerank function
14
+ - **Async Processing**: Non-blocking rerank operations
15
+ - **Error Handling**: Graceful fallback to original results
16
+ - **Optional Feature**: Can be enabled/disabled via configuration
17
+ - **Code Reuse**: Single generic implementation for Jina/Cohere compatible APIs
18
+
19
+ ## Configuration
20
+
21
+ ### Environment Variables
22
+
23
+ Set this variable in your `.env` file or environment:
24
+
25
+ ```bash
26
+ # Enable/disable reranking
27
+ ENABLE_RERANK=True
28
+ ```
29
+
30
+ ### Programmatic Configuration
31
+
32
+ ```python
33
+ from lightrag import LightRAG
34
+ from lightrag.rerank import custom_rerank, RerankModel
35
+
36
+ # Method 1: Using a custom rerank function with all settings included
37
+ async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
38
+ return await custom_rerank(
39
+ query=query,
40
+ documents=documents,
41
+ model="BAAI/bge-reranker-v2-m3",
42
+ base_url="https://api.your-provider.com/v1/rerank",
43
+ api_key="your_api_key_here",
44
+ top_k=top_k or 10, # Handle top_k within the function
45
+ **kwargs
46
+ )
47
+
48
+ rag = LightRAG(
49
+ working_dir="./rag_storage",
50
+ llm_model_func=your_llm_func,
51
+ embedding_func=your_embedding_func,
52
+ enable_rerank=True,
53
+ rerank_model_func=my_rerank_func,
54
+ )
55
+
56
+ # Method 2: Using RerankModel wrapper
57
+ rerank_model = RerankModel(
58
+ rerank_func=custom_rerank,
59
+ kwargs={
60
+ "model": "BAAI/bge-reranker-v2-m3",
61
+ "base_url": "https://api.your-provider.com/v1/rerank",
62
+ "api_key": "your_api_key_here",
63
+ }
64
+ )
65
+
66
+ rag = LightRAG(
67
+ working_dir="./rag_storage",
68
+ llm_model_func=your_llm_func,
69
+ embedding_func=your_embedding_func,
70
+ enable_rerank=True,
71
+ rerank_model_func=rerank_model.rerank,
72
+ )
73
+ ```
74
+
75
+ ## Supported Providers
76
+
77
+ ### 1. Custom/Generic API (Recommended)
78
+
79
+ For Jina/Cohere compatible APIs:
80
+
81
+ ```python
82
+ from lightrag.rerank import custom_rerank
83
+
84
+ # Your custom API endpoint
85
+ result = await custom_rerank(
86
+ query="your query",
87
+ documents=documents,
88
+ model="BAAI/bge-reranker-v2-m3",
89
+ base_url="https://api.your-provider.com/v1/rerank",
90
+ api_key="your_api_key_here",
91
+ top_k=10
92
+ )
93
+ ```
94
+
95
+ ### 2. Jina AI
96
+
97
+ ```python
98
+ from lightrag.rerank import jina_rerank
99
+
100
+ result = await jina_rerank(
101
+ query="your query",
102
+ documents=documents,
103
+ model="BAAI/bge-reranker-v2-m3",
104
+ api_key="your_jina_api_key",
105
+ top_k=10
106
+ )
107
+ ```
108
+
109
+ ### 3. Cohere
110
+
111
+ ```python
112
+ from lightrag.rerank import cohere_rerank
113
+
114
+ result = await cohere_rerank(
115
+ query="your query",
116
+ documents=documents,
117
+ model="rerank-english-v2.0",
118
+ api_key="your_cohere_api_key",
119
+ top_k=10
120
+ )
121
+ ```
122
+
123
+ ## Integration Points
124
+
125
+ Reranking is automatically applied at these key retrieval stages:
126
+
127
+ 1. **Naive Mode**: After vector similarity search in `_get_vector_context`
128
+ 2. **Local Mode**: After entity retrieval in `_get_node_data`
129
+ 3. **Global Mode**: After relationship retrieval in `_get_edge_data`
130
+ 4. **Hybrid/Mix Modes**: Applied to all relevant components
131
+
132
+ ## Configuration Parameters
133
+
134
+ | Parameter | Type | Default | Description |
135
+ |-----------|------|---------|-------------|
136
+ | `enable_rerank` | bool | False | Enable/disable reranking |
137
+ | `rerank_model_func` | callable | None | Custom rerank function containing all configurations (model, API keys, top_k, etc.) |
138
+
139
+ ## Example Usage
140
+
141
+ ### Basic Usage
142
+
143
+ ```python
144
+ import asyncio
145
+ from lightrag import LightRAG, QueryParam
146
+ from lightrag.llm.openai import gpt_4o_mini_complete, openai_embedding
147
+ from lightrag.kg.shared_storage import initialize_pipeline_status
148
+ from lightrag.rerank import jina_rerank
149
+
150
+ async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
151
+ """Custom rerank function with all settings included"""
152
+ return await jina_rerank(
153
+ query=query,
154
+ documents=documents,
155
+ model="BAAI/bge-reranker-v2-m3",
156
+ api_key="your_jina_api_key_here",
157
+ top_k=top_k or 10, # Default top_k if not provided
158
+ **kwargs
159
+ )
160
+
161
+ async def main():
162
+ # Initialize with rerank enabled
163
+ rag = LightRAG(
164
+ working_dir="./rag_storage",
165
+ llm_model_func=gpt_4o_mini_complete,
166
+ embedding_func=openai_embedding,
167
+ enable_rerank=True,
168
+ rerank_model_func=my_rerank_func,
169
+ )
170
+
171
+ await rag.initialize_storages()
172
+ await initialize_pipeline_status()
173
+
174
+ # Insert documents
175
+ await rag.ainsert([
176
+ "Document 1 content...",
177
+ "Document 2 content...",
178
+ ])
179
+
180
+ # Query with rerank (automatically applied)
181
+ result = await rag.aquery(
182
+ "Your question here",
183
+ param=QueryParam(mode="hybrid", top_k=5) # This top_k is passed to rerank function
184
+ )
185
+
186
+ print(result)
187
+
188
+ asyncio.run(main())
189
+ ```
190
+
191
+ ### Direct Rerank Usage
192
+
193
+ ```python
194
+ from lightrag.rerank import custom_rerank
195
+
196
+ async def test_rerank():
197
+ documents = [
198
+ {"content": "Text about topic A"},
199
+ {"content": "Text about topic B"},
200
+ {"content": "Text about topic C"},
201
+ ]
202
+
203
+ reranked = await custom_rerank(
204
+ query="Tell me about topic A",
205
+ documents=documents,
206
+ model="BAAI/bge-reranker-v2-m3",
207
+ base_url="https://api.your-provider.com/v1/rerank",
208
+ api_key="your_api_key_here",
209
+ top_k=2
210
+ )
211
+
212
+ for doc in reranked:
213
+ print(f"Score: {doc.get('rerank_score')}, Content: {doc.get('content')}")
214
+ ```
215
+
216
+ ## Best Practices
217
+
218
+ 1. **Self-Contained Functions**: Include all necessary configurations (API keys, models, top_k handling) within your rerank function
219
+ 2. **Performance**: Use reranking selectively for better performance vs. quality tradeoff
220
+ 3. **API Limits**: Monitor API usage and implement rate limiting within your rerank function
221
+ 4. **Fallback**: Always handle rerank failures gracefully (returns original results)
222
+ 5. **Top-k Handling**: Handle top_k parameter appropriately within your rerank function
223
+ 6. **Cost Management**: Consider rerank API costs in your budget planning
224
+
225
+ ## Troubleshooting
226
+
227
+ ### Common Issues
228
+
229
+ 1. **API Key Missing**: Ensure API keys are properly configured within your rerank function
230
+ 2. **Network Issues**: Check API endpoints and network connectivity
231
+ 3. **Model Errors**: Verify the rerank model name is supported by your API
232
+ 4. **Document Format**: Ensure documents have `content` or `text` fields
233
+
234
+ ### Debug Mode
235
+
236
+ Enable debug logging to see rerank operations:
237
+
238
+ ```python
239
+ import logging
240
+ logging.getLogger("lightrag.rerank").setLevel(logging.DEBUG)
241
+ ```
242
+
243
+ ### Error Handling
244
+
245
+ The rerank integration includes automatic fallback:
246
+
247
+ ```python
248
+ # If rerank fails, original documents are returned
249
+ # No exceptions are raised to the user
250
+ # Errors are logged for debugging
251
+ ```
252
+
253
+ ## API Compatibility
254
+
255
+ The generic rerank API expects this response format:
256
+
257
+ ```json
258
+ {
259
+ "results": [
260
+ {
261
+ "index": 0,
262
+ "relevance_score": 0.95
263
+ },
264
+ {
265
+ "index": 2,
266
+ "relevance_score": 0.87
267
+ }
268
+ ]
269
+ }
270
+ ```
271
+
272
+ This is compatible with:
273
+ - Jina AI Rerank API
274
+ - Cohere Rerank API
275
+ - Custom APIs following the same format
env.example CHANGED
@@ -42,13 +42,31 @@ OLLAMA_EMULATING_MODEL_TAG=latest
42
  ### Logfile location (defaults to current working directory)
43
  # LOG_DIR=/path/to/log/directory
44
 
45
- ### Settings for RAG query
 
 
 
 
 
 
46
  # HISTORY_TURNS=3
47
- # COSINE_THRESHOLD=0.2
48
- # TOP_K=60
49
- # MAX_TOKEN_TEXT_CHUNK=4000
50
  # MAX_TOKEN_RELATION_DESC=4000
51
  # MAX_TOKEN_ENTITY_DESC=4000
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  ### Entity and relation summarization configuration
54
  ### Language: English, Chinese, French, German ...
@@ -62,9 +80,6 @@ SUMMARY_LANGUAGE=English
62
 
63
  ### Number of parallel processing documents(Less than MAX_ASYNC/2 is recommended)
64
  # MAX_PARALLEL_INSERT=2
65
- ### Chunk size for document splitting, 500~1500 is recommended
66
- # CHUNK_SIZE=1200
67
- # CHUNK_OVERLAP_SIZE=100
68
 
69
  ### LLM Configuration
70
  ENABLE_LLM_CACHE=true
 
42
  ### Logfile location (defaults to current working directory)
43
  # LOG_DIR=/path/to/log/directory
44
 
45
+ ### RAG Configuration
46
+ ### Chunk size for document splitting, 500~1500 is recommended
47
+ # CHUNK_SIZE=1200
48
+ # CHUNK_OVERLAP_SIZE=100
49
+ # MAX_TOKEN_SUMMARY=500
50
+
51
+ ### RAG Query Configuration
52
  # HISTORY_TURNS=3
53
+ # MAX_TOKEN_TEXT_CHUNK=6000
 
 
54
  # MAX_TOKEN_RELATION_DESC=4000
55
  # MAX_TOKEN_ENTITY_DESC=4000
56
+ # COSINE_THRESHOLD=0.2
57
+ ### Number of entities or relations to retrieve from KG
58
+ # TOP_K=60
59
+ ### Number of text chunks to retrieve initially from vector search
60
+ # CHUNK_TOP_K=5
61
+
62
+ ### Rerank Configuration
63
+ # ENABLE_RERANK=False
64
+ ### Number of text chunks to keep after reranking (should be <= CHUNK_TOP_K)
65
+ # CHUNK_RERANK_TOP_K=5
66
+ ### Rerank model configuration (required when ENABLE_RERANK=True)
67
+ # RERANK_MODEL=BAAI/bge-reranker-v2-m3
68
+ # RERANK_BINDING_HOST=https://api.your-rerank-provider.com/v1/rerank
69
+ # RERANK_BINDING_API_KEY=your_rerank_api_key_here
70
 
71
  ### Entity and relation summarization configuration
72
  ### Language: English, Chinese, French, German ...
 
80
 
81
  ### Number of parallel processing documents(Less than MAX_ASYNC/2 is recommended)
82
  # MAX_PARALLEL_INSERT=2
 
 
 
83
 
84
  ### LLM Configuration
85
  ENABLE_LLM_CACHE=true
examples/rerank_example.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LightRAG Rerank Integration Example
3
+
4
+ This example demonstrates how to use rerank functionality with LightRAG
5
+ to improve retrieval quality across different query modes.
6
+
7
+ Configuration Required:
8
+ 1. Set your LLM API key and base URL in llm_model_func()
9
+ 2. Set your embedding API key and base URL in embedding_func()
10
+ 3. Set your rerank API key and base URL in the rerank configuration
11
+ 4. Or use environment variables (.env file):
12
+ - ENABLE_RERANK=True
13
+ """
14
+
15
+ import asyncio
16
+ import os
17
+ import numpy as np
18
+
19
+ from lightrag import LightRAG, QueryParam
20
+ from lightrag.rerank import custom_rerank, RerankModel
21
+ from lightrag.llm.openai import openai_complete_if_cache, openai_embed
22
+ from lightrag.utils import EmbeddingFunc, setup_logger
23
+ from lightrag.kg.shared_storage import initialize_pipeline_status
24
+
25
+ # Set up your working directory
26
+ WORKING_DIR = "./test_rerank"
27
+ setup_logger("test_rerank")
28
+
29
+ if not os.path.exists(WORKING_DIR):
30
+ os.mkdir(WORKING_DIR)
31
+
32
+
33
+ async def llm_model_func(
34
+ prompt, system_prompt=None, history_messages=[], **kwargs
35
+ ) -> str:
36
+ return await openai_complete_if_cache(
37
+ "gpt-4o-mini",
38
+ prompt,
39
+ system_prompt=system_prompt,
40
+ history_messages=history_messages,
41
+ api_key="your_llm_api_key_here",
42
+ base_url="https://api.your-llm-provider.com/v1",
43
+ **kwargs,
44
+ )
45
+
46
+
47
+ async def embedding_func(texts: list[str]) -> np.ndarray:
48
+ return await openai_embed(
49
+ texts,
50
+ model="text-embedding-3-large",
51
+ api_key="your_embedding_api_key_here",
52
+ base_url="https://api.your-embedding-provider.com/v1",
53
+ )
54
+
55
+
56
+ async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
57
+ """Custom rerank function with all settings included"""
58
+ return await custom_rerank(
59
+ query=query,
60
+ documents=documents,
61
+ model="BAAI/bge-reranker-v2-m3",
62
+ base_url="https://api.your-rerank-provider.com/v1/rerank",
63
+ api_key="your_rerank_api_key_here",
64
+ top_k=top_k or 10, # Default top_k if not provided
65
+ **kwargs,
66
+ )
67
+
68
+
69
+ async def create_rag_with_rerank():
70
+ """Create LightRAG instance with rerank configuration"""
71
+
72
+ # Get embedding dimension
73
+ test_embedding = await embedding_func(["test"])
74
+ embedding_dim = test_embedding.shape[1]
75
+ print(f"Detected embedding dimension: {embedding_dim}")
76
+
77
+ # Method 1: Using custom rerank function
78
+ rag = LightRAG(
79
+ working_dir=WORKING_DIR,
80
+ llm_model_func=llm_model_func,
81
+ embedding_func=EmbeddingFunc(
82
+ embedding_dim=embedding_dim,
83
+ max_token_size=8192,
84
+ func=embedding_func,
85
+ ),
86
+ # Simplified Rerank Configuration
87
+ enable_rerank=True,
88
+ rerank_model_func=my_rerank_func,
89
+ )
90
+
91
+ await rag.initialize_storages()
92
+ await initialize_pipeline_status()
93
+
94
+ return rag
95
+
96
+
97
+ async def create_rag_with_rerank_model():
98
+ """Alternative: Create LightRAG instance using RerankModel wrapper"""
99
+
100
+ # Get embedding dimension
101
+ test_embedding = await embedding_func(["test"])
102
+ embedding_dim = test_embedding.shape[1]
103
+ print(f"Detected embedding dimension: {embedding_dim}")
104
+
105
+ # Method 2: Using RerankModel wrapper
106
+ rerank_model = RerankModel(
107
+ rerank_func=custom_rerank,
108
+ kwargs={
109
+ "model": "BAAI/bge-reranker-v2-m3",
110
+ "base_url": "https://api.your-rerank-provider.com/v1/rerank",
111
+ "api_key": "your_rerank_api_key_here",
112
+ },
113
+ )
114
+
115
+ rag = LightRAG(
116
+ working_dir=WORKING_DIR,
117
+ llm_model_func=llm_model_func,
118
+ embedding_func=EmbeddingFunc(
119
+ embedding_dim=embedding_dim,
120
+ max_token_size=8192,
121
+ func=embedding_func,
122
+ ),
123
+ enable_rerank=True,
124
+ rerank_model_func=rerank_model.rerank,
125
+ )
126
+
127
+ await rag.initialize_storages()
128
+ await initialize_pipeline_status()
129
+
130
+ return rag
131
+
132
+
133
+ async def test_rerank_with_different_topk():
134
+ """
135
+ Test rerank functionality with different top_k settings
136
+ """
137
+ print("🚀 Setting up LightRAG with Rerank functionality...")
138
+
139
+ rag = await create_rag_with_rerank()
140
+
141
+ # Insert sample documents
142
+ sample_docs = [
143
+ "Reranking improves retrieval quality by re-ordering documents based on relevance.",
144
+ "LightRAG is a powerful retrieval-augmented generation system with multiple query modes.",
145
+ "Vector databases enable efficient similarity search in high-dimensional embedding spaces.",
146
+ "Natural language processing has evolved with large language models and transformers.",
147
+ "Machine learning algorithms can learn patterns from data without explicit programming.",
148
+ ]
149
+
150
+ print("📄 Inserting sample documents...")
151
+ await rag.ainsert(sample_docs)
152
+
153
+ query = "How does reranking improve retrieval quality?"
154
+ print(f"\n🔍 Testing query: '{query}'")
155
+ print("=" * 80)
156
+
157
+ # Test different top_k values to show parameter priority
158
+ top_k_values = [2, 5, 10]
159
+
160
+ for top_k in top_k_values:
161
+ print(f"\n📊 Testing with QueryParam(top_k={top_k}):")
162
+
163
+ # Test naive mode with specific top_k
164
+ result = await rag.aquery(query, param=QueryParam(mode="naive", top_k=top_k))
165
+ print(f" Result length: {len(result)} characters")
166
+ print(f" Preview: {result[:100]}...")
167
+
168
+
169
+ async def test_direct_rerank():
170
+ """Test rerank function directly"""
171
+ print("\n🔧 Direct Rerank API Test")
172
+ print("=" * 40)
173
+
174
+ documents = [
175
+ {"content": "Reranking significantly improves retrieval quality"},
176
+ {"content": "LightRAG supports advanced reranking capabilities"},
177
+ {"content": "Vector search finds semantically similar documents"},
178
+ {"content": "Natural language processing with modern transformers"},
179
+ {"content": "The quick brown fox jumps over the lazy dog"},
180
+ ]
181
+
182
+ query = "rerank improve quality"
183
+ print(f"Query: '{query}'")
184
+ print(f"Documents: {len(documents)}")
185
+
186
+ try:
187
+ reranked_docs = await custom_rerank(
188
+ query=query,
189
+ documents=documents,
190
+ model="BAAI/bge-reranker-v2-m3",
191
+ base_url="https://api.your-rerank-provider.com/v1/rerank",
192
+ api_key="your_rerank_api_key_here",
193
+ top_k=3,
194
+ )
195
+
196
+ print("\n✅ Rerank Results:")
197
+ for i, doc in enumerate(reranked_docs):
198
+ score = doc.get("rerank_score", "N/A")
199
+ content = doc.get("content", "")[:60]
200
+ print(f" {i+1}. Score: {score:.4f} | {content}...")
201
+
202
+ except Exception as e:
203
+ print(f"❌ Rerank failed: {e}")
204
+
205
+
206
+ async def main():
207
+ """Main example function"""
208
+ print("🎯 LightRAG Rerank Integration Example")
209
+ print("=" * 60)
210
+
211
+ try:
212
+ # Test rerank with different top_k values
213
+ await test_rerank_with_different_topk()
214
+
215
+ # Test direct rerank
216
+ await test_direct_rerank()
217
+
218
+ print("\n✅ Example completed successfully!")
219
+ print("\n💡 Key Points:")
220
+ print(" ✓ All rerank configurations are contained within rerank_model_func")
221
+ print(" ✓ Rerank improves document relevance ordering")
222
+ print(" ✓ Configure API keys within your rerank function")
223
+ print(" ✓ Monitor API usage and costs when using rerank services")
224
+
225
+ except Exception as e:
226
+ print(f"\n❌ Example failed: {e}")
227
+ import traceback
228
+
229
+ traceback.print_exc()
230
+
231
+
232
+ if __name__ == "__main__":
233
+ asyncio.run(main())
lightrag/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
  from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
2
 
3
- __version__ = "1.3.10"
4
  __author__ = "Zirui Guo"
5
  __url__ = "https://github.com/HKUDS/LightRAG"
 
1
  from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
2
 
3
+ __version__ = "1.4.0"
4
  __author__ = "Zirui Guo"
5
  __url__ = "https://github.com/HKUDS/LightRAG"
lightrag/api/config.py CHANGED
@@ -165,6 +165,24 @@ def parse_args() -> argparse.Namespace:
165
  default=get_env_value("TOP_K", 60, int),
166
  help="Number of most similar results to return (default: from env or 60)",
167
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  parser.add_argument(
169
  "--cosine-threshold",
170
  type=float,
@@ -295,6 +313,11 @@ def parse_args() -> argparse.Namespace:
295
  args.guest_token_expire_hours = get_env_value("GUEST_TOKEN_EXPIRE_HOURS", 24, int)
296
  args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256")
297
 
 
 
 
 
 
298
  ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
299
 
300
  return args
 
165
  default=get_env_value("TOP_K", 60, int),
166
  help="Number of most similar results to return (default: from env or 60)",
167
  )
168
+ parser.add_argument(
169
+ "--chunk-top-k",
170
+ type=int,
171
+ default=get_env_value("CHUNK_TOP_K", 15, int),
172
+ help="Number of text chunks to retrieve initially from vector search (default: from env or 15)",
173
+ )
174
+ parser.add_argument(
175
+ "--chunk-rerank-top-k",
176
+ type=int,
177
+ default=get_env_value("CHUNK_RERANK_TOP_K", 5, int),
178
+ help="Number of text chunks to keep after reranking (default: from env or 5)",
179
+ )
180
+ parser.add_argument(
181
+ "--enable-rerank",
182
+ action="store_true",
183
+ default=get_env_value("ENABLE_RERANK", False, bool),
184
+ help="Enable rerank functionality (default: from env or False)",
185
+ )
186
  parser.add_argument(
187
  "--cosine-threshold",
188
  type=float,
 
313
  args.guest_token_expire_hours = get_env_value("GUEST_TOKEN_EXPIRE_HOURS", 24, int)
314
  args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256")
315
 
316
+ # Rerank model configuration
317
+ args.rerank_model = get_env_value("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
318
+ args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None)
319
+ args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None)
320
+
321
  ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
322
 
323
  return args
lightrag/api/lightrag_server.py CHANGED
@@ -291,6 +291,32 @@ def create_app(args):
291
  ),
292
  )
293
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  # Initialize RAG
295
  if args.llm_binding in ["lollms", "ollama", "openai"]:
296
  rag = LightRAG(
@@ -324,6 +350,8 @@ def create_app(args):
324
  },
325
  enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
326
  enable_llm_cache=args.enable_llm_cache,
 
 
327
  auto_manage_storages_states=False,
328
  max_parallel_insert=args.max_parallel_insert,
329
  max_graph_nodes=args.max_graph_nodes,
@@ -352,6 +380,8 @@ def create_app(args):
352
  },
353
  enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
354
  enable_llm_cache=args.enable_llm_cache,
 
 
355
  auto_manage_storages_states=False,
356
  max_parallel_insert=args.max_parallel_insert,
357
  max_graph_nodes=args.max_graph_nodes,
@@ -478,6 +508,12 @@ def create_app(args):
478
  "enable_llm_cache": args.enable_llm_cache,
479
  "workspace": args.workspace,
480
  "max_graph_nodes": args.max_graph_nodes,
 
 
 
 
 
 
481
  },
482
  "auth_mode": auth_mode,
483
  "pipeline_busy": pipeline_status.get("busy", False),
 
291
  ),
292
  )
293
 
294
+ # Configure rerank function if enabled
295
+ rerank_model_func = None
296
+ if args.enable_rerank and args.rerank_binding_api_key and args.rerank_binding_host:
297
+ from lightrag.rerank import custom_rerank
298
+
299
+ async def server_rerank_func(
300
+ query: str, documents: list, top_k: int = None, **kwargs
301
+ ):
302
+ """Server rerank function with configuration from environment variables"""
303
+ return await custom_rerank(
304
+ query=query,
305
+ documents=documents,
306
+ model=args.rerank_model,
307
+ base_url=args.rerank_binding_host,
308
+ api_key=args.rerank_binding_api_key,
309
+ top_k=top_k,
310
+ **kwargs,
311
+ )
312
+
313
+ rerank_model_func = server_rerank_func
314
+ logger.info(f"Rerank enabled with model: {args.rerank_model}")
315
+ elif args.enable_rerank:
316
+ logger.warning(
317
+ "Rerank enabled but RERANK_BINDING_API_KEY or RERANK_BINDING_HOST not configured. Rerank will be disabled."
318
+ )
319
+
320
  # Initialize RAG
321
  if args.llm_binding in ["lollms", "ollama", "openai"]:
322
  rag = LightRAG(
 
350
  },
351
  enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
352
  enable_llm_cache=args.enable_llm_cache,
353
+ enable_rerank=args.enable_rerank,
354
+ rerank_model_func=rerank_model_func,
355
  auto_manage_storages_states=False,
356
  max_parallel_insert=args.max_parallel_insert,
357
  max_graph_nodes=args.max_graph_nodes,
 
380
  },
381
  enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
382
  enable_llm_cache=args.enable_llm_cache,
383
+ enable_rerank=args.enable_rerank,
384
+ rerank_model_func=rerank_model_func,
385
  auto_manage_storages_states=False,
386
  max_parallel_insert=args.max_parallel_insert,
387
  max_graph_nodes=args.max_graph_nodes,
 
508
  "enable_llm_cache": args.enable_llm_cache,
509
  "workspace": args.workspace,
510
  "max_graph_nodes": args.max_graph_nodes,
511
+ # Rerank configuration
512
+ "enable_rerank": args.enable_rerank,
513
+ "rerank_model": args.rerank_model if args.enable_rerank else None,
514
+ "rerank_binding_host": args.rerank_binding_host
515
+ if args.enable_rerank
516
+ else None,
517
  },
518
  "auth_mode": auth_mode,
519
  "pipeline_busy": pipeline_status.get("busy", False),
lightrag/api/routers/query_routes.py CHANGED
@@ -49,6 +49,18 @@ class QueryRequest(BaseModel):
49
  description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.",
50
  )
51
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  max_token_for_text_unit: Optional[int] = Field(
53
  gt=1,
54
  default=None,
 
49
  description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.",
50
  )
51
 
52
+ chunk_top_k: Optional[int] = Field(
53
+ ge=1,
54
+ default=None,
55
+ description="Number of text chunks to retrieve initially from vector search.",
56
+ )
57
+
58
+ chunk_rerank_top_k: Optional[int] = Field(
59
+ ge=1,
60
+ default=None,
61
+ description="Number of text chunks to keep after reranking.",
62
+ )
63
+
64
  max_token_for_text_unit: Optional[int] = Field(
65
  gt=1,
66
  default=None,
lightrag/base.py CHANGED
@@ -60,7 +60,17 @@ class QueryParam:
60
  top_k: int = int(os.getenv("TOP_K", "60"))
61
  """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
62
 
63
- max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))
 
 
 
 
 
 
 
 
 
 
64
  """Maximum number of tokens allowed for each retrieved text chunk."""
65
 
66
  max_token_for_global_context: int = int(
@@ -280,21 +290,6 @@ class BaseKVStorage(StorageNameSpace, ABC):
280
  False: if the cache drop failed, or the cache mode is not supported
281
  """
282
 
283
- # async def drop_cache_by_chunk_ids(self, chunk_ids: list[str] | None = None) -> bool:
284
- # """Delete specific cache records from storage by chunk IDs
285
-
286
- # Importance notes for in-memory storage:
287
- # 1. Changes will be persisted to disk during the next index_done_callback
288
- # 2. update flags to notify other processes that data persistence is needed
289
-
290
- # Args:
291
- # chunk_ids (list[str]): List of chunk IDs to be dropped from storage
292
-
293
- # Returns:
294
- # True: if the cache drop successfully
295
- # False: if the cache drop failed, or the operation is not supported
296
- # """
297
-
298
 
299
  @dataclass
300
  class BaseGraphStorage(StorageNameSpace, ABC):
 
60
  top_k: int = int(os.getenv("TOP_K", "60"))
61
  """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
62
 
63
+ chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5"))
64
+ """Number of text chunks to retrieve initially from vector search.
65
+ If None, defaults to top_k value.
66
+ """
67
+
68
+ chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5"))
69
+ """Number of text chunks to keep after reranking.
70
+ If None, keeps all chunks returned from initial retrieval.
71
+ """
72
+
73
+ max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "6000"))
74
  """Maximum number of tokens allowed for each retrieved text chunk."""
75
 
76
  max_token_for_global_context: int = int(
 
290
  False: if the cache drop failed, or the cache mode is not supported
291
  """
292
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
 
294
  @dataclass
295
  class BaseGraphStorage(StorageNameSpace, ABC):
lightrag/lightrag.py CHANGED
@@ -240,6 +240,17 @@ class LightRAG:
240
  llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
241
  """Additional keyword arguments passed to the LLM model function."""
242
 
 
 
 
 
 
 
 
 
 
 
 
243
  # Storage
244
  # ---
245
 
@@ -447,6 +458,14 @@ class LightRAG:
447
  )
448
  )
449
 
 
 
 
 
 
 
 
 
450
  self._storages_status = StoragesStatus.CREATED
451
 
452
  if self.auto_manage_storages_states:
 
240
  llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
241
  """Additional keyword arguments passed to the LLM model function."""
242
 
243
+ # Rerank Configuration
244
+ # ---
245
+
246
+ enable_rerank: bool = field(
247
+ default=bool(os.getenv("ENABLE_RERANK", "False").lower() == "true")
248
+ )
249
+ """Enable reranking for improved retrieval quality. Defaults to False."""
250
+
251
+ rerank_model_func: Callable[..., object] | None = field(default=None)
252
+ """Function for reranking retrieved documents. All rerank configurations (model name, API keys, top_k, etc.) should be included in this function. Optional."""
253
+
254
  # Storage
255
  # ---
256
 
 
458
  )
459
  )
460
 
461
+ # Init Rerank
462
+ if self.enable_rerank and self.rerank_model_func:
463
+ logger.info("Rerank model initialized for improved retrieval quality")
464
+ elif self.enable_rerank and not self.rerank_model_func:
465
+ logger.warning(
466
+ "Rerank is enabled but no rerank_model_func provided. Reranking will be skipped."
467
+ )
468
+
469
  self._storages_status = StoragesStatus.CREATED
470
 
471
  if self.auto_manage_storages_states:
lightrag/operate.py CHANGED
@@ -1527,6 +1527,7 @@ async def kg_query(
1527
 
1528
  # Build context
1529
  context = await _build_query_context(
 
1530
  ll_keywords_str,
1531
  hl_keywords_str,
1532
  knowledge_graph_inst,
@@ -1746,84 +1747,52 @@ async def _get_vector_context(
1746
  query: str,
1747
  chunks_vdb: BaseVectorStorage,
1748
  query_param: QueryParam,
1749
- tokenizer: Tokenizer,
1750
- ) -> tuple[list, list, list] | None:
1751
  """
1752
- Retrieve vector context from the vector database.
1753
 
1754
- This function performs vector search to find relevant text chunks for a query,
1755
- formats them with file path and creation time information.
1756
 
1757
  Args:
1758
  query: The query string to search for
1759
  chunks_vdb: Vector database containing document chunks
1760
- query_param: Query parameters including top_k and ids
1761
- tokenizer: Tokenizer for counting tokens
1762
 
1763
  Returns:
1764
- Tuple (empty_entities, empty_relations, text_units) for combine_contexts,
1765
- compatible with _get_edge_data and _get_node_data format
1766
  """
1767
  try:
1768
- results = await chunks_vdb.query(
1769
- query, top_k=query_param.top_k, ids=query_param.ids
1770
- )
 
1771
  if not results:
1772
- return [], [], []
1773
 
1774
  valid_chunks = []
1775
  for result in results:
1776
  if "content" in result:
1777
- # Directly use content from chunks_vdb.query result
1778
- chunk_with_time = {
1779
  "content": result["content"],
1780
  "created_at": result.get("created_at", None),
1781
  "file_path": result.get("file_path", "unknown_source"),
 
1782
  }
1783
- valid_chunks.append(chunk_with_time)
1784
-
1785
- if not valid_chunks:
1786
- return [], [], []
1787
-
1788
- maybe_trun_chunks = truncate_list_by_token_size(
1789
- valid_chunks,
1790
- key=lambda x: x["content"],
1791
- max_token_size=query_param.max_token_for_text_unit,
1792
- tokenizer=tokenizer,
1793
- )
1794
 
1795
- logger.debug(
1796
- f"Truncate chunks from {len(valid_chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
1797
- )
1798
  logger.info(
1799
- f"Query chunks: {len(maybe_trun_chunks)} chunks, top_k: {query_param.top_k}"
1800
  )
 
1801
 
1802
- if not maybe_trun_chunks:
1803
- return [], [], []
1804
-
1805
- # Create empty entities and relations contexts
1806
- entities_context = []
1807
- relations_context = []
1808
-
1809
- # Create text_units_context directly as a list of dictionaries
1810
- text_units_context = []
1811
- for i, chunk in enumerate(maybe_trun_chunks):
1812
- text_units_context.append(
1813
- {
1814
- "id": i + 1,
1815
- "content": chunk["content"],
1816
- "file_path": chunk["file_path"],
1817
- }
1818
- )
1819
-
1820
- return entities_context, relations_context, text_units_context
1821
  except Exception as e:
1822
  logger.error(f"Error in _get_vector_context: {e}")
1823
- return [], [], []
1824
 
1825
 
1826
  async def _build_query_context(
 
1827
  ll_keywords: str,
1828
  hl_keywords: str,
1829
  knowledge_graph_inst: BaseGraphStorage,
@@ -1831,27 +1800,36 @@ async def _build_query_context(
1831
  relationships_vdb: BaseVectorStorage,
1832
  text_chunks_db: BaseKVStorage,
1833
  query_param: QueryParam,
1834
- chunks_vdb: BaseVectorStorage = None, # Add chunks_vdb parameter for mix mode
1835
  ):
1836
  logger.info(f"Process {os.getpid()} building query context...")
1837
 
1838
- # Handle local and global modes as before
 
 
 
 
 
1839
  if query_param.mode == "local":
1840
- entities_context, relations_context, text_units_context = await _get_node_data(
1841
  ll_keywords,
1842
  knowledge_graph_inst,
1843
  entities_vdb,
1844
  text_chunks_db,
1845
  query_param,
1846
  )
 
 
1847
  elif query_param.mode == "global":
1848
- entities_context, relations_context, text_units_context = await _get_edge_data(
1849
  hl_keywords,
1850
  knowledge_graph_inst,
1851
  relationships_vdb,
1852
  text_chunks_db,
1853
  query_param,
1854
  )
 
 
1855
  else: # hybrid or mix mode
1856
  ll_data = await _get_node_data(
1857
  ll_keywords,
@@ -1868,61 +1846,58 @@ async def _build_query_context(
1868
  query_param,
1869
  )
1870
 
1871
- (
1872
- ll_entities_context,
1873
- ll_relations_context,
1874
- ll_text_units_context,
1875
- ) = ll_data
1876
-
1877
- (
1878
- hl_entities_context,
1879
- hl_relations_context,
1880
- hl_text_units_context,
1881
- ) = hl_data
1882
-
1883
- # Initialize vector data with empty lists
1884
- vector_entities_context, vector_relations_context, vector_text_units_context = (
1885
- [],
1886
- [],
1887
- [],
1888
- )
1889
 
1890
- # Only get vector data if in mix mode
1891
- if query_param.mode == "mix" and hasattr(query_param, "original_query"):
1892
- # Get tokenizer from text_chunks_db
1893
- tokenizer = text_chunks_db.global_config.get("tokenizer")
1894
 
1895
- # Get vector context in triple format
1896
- vector_data = await _get_vector_context(
1897
- query_param.original_query, # We need to pass the original query
 
1898
  chunks_vdb,
1899
  query_param,
1900
- tokenizer,
1901
  )
 
1902
 
1903
- # If vector_data is not None, unpack it
1904
- if vector_data is not None:
1905
- (
1906
- vector_entities_context,
1907
- vector_relations_context,
1908
- vector_text_units_context,
1909
- ) = vector_data
1910
-
1911
- # Combine and deduplicate the entities, relationships, and sources
1912
  entities_context = process_combine_contexts(
1913
- hl_entities_context, ll_entities_context, vector_entities_context
1914
  )
1915
  relations_context = process_combine_contexts(
1916
- hl_relations_context, ll_relations_context, vector_relations_context
1917
  )
1918
- text_units_context = process_combine_contexts(
1919
- hl_text_units_context, ll_text_units_context, vector_text_units_context
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1920
  )
 
 
 
 
 
1921
  # not necessary to use LLM to generate a response
1922
  if not entities_context and not relations_context:
1923
  return None
1924
 
1925
- # 转换为 JSON 字符串
1926
  entities_str = json.dumps(entities_context, ensure_ascii=False)
1927
  relations_str = json.dumps(relations_context, ensure_ascii=False)
1928
  text_units_str = json.dumps(text_units_context, ensure_ascii=False)
@@ -2069,16 +2044,7 @@ async def _get_node_data(
2069
  }
2070
  )
2071
 
2072
- text_units_context = []
2073
- for i, t in enumerate(use_text_units):
2074
- text_units_context.append(
2075
- {
2076
- "id": i + 1,
2077
- "content": t["content"],
2078
- "file_path": t.get("file_path", "unknown_source"),
2079
- }
2080
- )
2081
- return entities_context, relations_context, text_units_context
2082
 
2083
 
2084
  async def _find_most_related_text_unit_from_entities(
@@ -2167,23 +2133,21 @@ async def _find_most_related_text_unit_from_entities(
2167
  logger.warning("No valid text units found")
2168
  return []
2169
 
2170
- tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
2171
  all_text_units = sorted(
2172
  all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
2173
  )
2174
- all_text_units = truncate_list_by_token_size(
2175
- all_text_units,
2176
- key=lambda x: x["data"]["content"],
2177
- max_token_size=query_param.max_token_for_text_unit,
2178
- tokenizer=tokenizer,
2179
- )
2180
 
2181
- logger.debug(
2182
- f"Truncate chunks from {len(all_text_units_lookup)} to {len(all_text_units)} (max tokens:{query_param.max_token_for_text_unit})"
2183
- )
2184
 
2185
- all_text_units = [t["data"] for t in all_text_units]
2186
- return all_text_units
 
 
 
 
 
 
2187
 
2188
 
2189
  async def _find_most_related_edges_from_entities(
@@ -2485,21 +2449,16 @@ async def _find_related_text_unit_from_relationships(
2485
  logger.warning("No valid text chunks after filtering")
2486
  return []
2487
 
2488
- tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
2489
- truncated_text_units = truncate_list_by_token_size(
2490
- valid_text_units,
2491
- key=lambda x: x["data"]["content"],
2492
- max_token_size=query_param.max_token_for_text_unit,
2493
- tokenizer=tokenizer,
2494
- )
2495
-
2496
- logger.debug(
2497
- f"Truncate chunks from {len(valid_text_units)} to {len(truncated_text_units)} (max tokens:{query_param.max_token_for_text_unit})"
2498
- )
2499
 
2500
- all_text_units: list[TextChunkSchema] = [t["data"] for t in truncated_text_units]
 
 
 
 
 
2501
 
2502
- return all_text_units
2503
 
2504
 
2505
  async def naive_query(
@@ -2527,13 +2486,33 @@ async def naive_query(
2527
 
2528
  tokenizer: Tokenizer = global_config["tokenizer"]
2529
 
2530
- _, _, text_units_context = await _get_vector_context(
2531
- query, chunks_vdb, query_param, tokenizer
2532
- )
2533
 
2534
- if text_units_context is None or len(text_units_context) == 0:
2535
  return PROMPTS["fail_response"]
2536
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2537
  text_units_str = json.dumps(text_units_context, ensure_ascii=False)
2538
  if query_param.only_need_context:
2539
  return f"""
@@ -2658,6 +2637,7 @@ async def kg_query_with_keywords(
2658
  hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""
2659
 
2660
  context = await _build_query_context(
 
2661
  ll_keywords_str,
2662
  hl_keywords_str,
2663
  knowledge_graph_inst,
@@ -2780,8 +2760,6 @@ async def query_with_keywords(
2780
  f"{prompt}\n\n### Keywords\n\n{keywords_str}\n\n### Query\n\n{query}"
2781
  )
2782
 
2783
- param.original_query = query
2784
-
2785
  # Use appropriate query method based on mode
2786
  if param.mode in ["local", "global", "hybrid", "mix"]:
2787
  return await kg_query_with_keywords(
@@ -2808,3 +2786,131 @@ async def query_with_keywords(
2808
  )
2809
  else:
2810
  raise ValueError(f"Unknown mode {param.mode}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1527
 
1528
  # Build context
1529
  context = await _build_query_context(
1530
+ query,
1531
  ll_keywords_str,
1532
  hl_keywords_str,
1533
  knowledge_graph_inst,
 
1747
  query: str,
1748
  chunks_vdb: BaseVectorStorage,
1749
  query_param: QueryParam,
1750
+ ) -> list[dict]:
 
1751
  """
1752
+ Retrieve text chunks from the vector database without reranking or truncation.
1753
 
1754
+ This function performs vector search to find relevant text chunks for a query.
1755
+ Reranking and truncation will be handled later in the unified processing.
1756
 
1757
  Args:
1758
  query: The query string to search for
1759
  chunks_vdb: Vector database containing document chunks
1760
+ query_param: Query parameters including chunk_top_k and ids
 
1761
 
1762
  Returns:
1763
+ List of text chunks with metadata
 
1764
  """
1765
  try:
1766
+ # Use chunk_top_k if specified, otherwise fall back to top_k
1767
+ search_top_k = query_param.chunk_top_k or query_param.top_k
1768
+
1769
+ results = await chunks_vdb.query(query, top_k=search_top_k, ids=query_param.ids)
1770
  if not results:
1771
+ return []
1772
 
1773
  valid_chunks = []
1774
  for result in results:
1775
  if "content" in result:
1776
+ chunk_with_metadata = {
 
1777
  "content": result["content"],
1778
  "created_at": result.get("created_at", None),
1779
  "file_path": result.get("file_path", "unknown_source"),
1780
+ "source_type": "vector", # Mark the source type
1781
  }
1782
+ valid_chunks.append(chunk_with_metadata)
 
 
 
 
 
 
 
 
 
 
1783
 
 
 
 
1784
  logger.info(
1785
+ f"Naive query: {len(valid_chunks)} chunks (chunk_top_k: {search_top_k})"
1786
  )
1787
+ return valid_chunks
1788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1789
  except Exception as e:
1790
  logger.error(f"Error in _get_vector_context: {e}")
1791
+ return []
1792
 
1793
 
1794
  async def _build_query_context(
1795
+ query: str,
1796
  ll_keywords: str,
1797
  hl_keywords: str,
1798
  knowledge_graph_inst: BaseGraphStorage,
 
1800
  relationships_vdb: BaseVectorStorage,
1801
  text_chunks_db: BaseKVStorage,
1802
  query_param: QueryParam,
1803
+ chunks_vdb: BaseVectorStorage = None,
1804
  ):
1805
  logger.info(f"Process {os.getpid()} building query context...")
1806
 
1807
+ # Collect all chunks from different sources
1808
+ all_chunks = []
1809
+ entities_context = []
1810
+ relations_context = []
1811
+
1812
+ # Handle local and global modes
1813
  if query_param.mode == "local":
1814
+ entities_context, relations_context, entity_chunks = await _get_node_data(
1815
  ll_keywords,
1816
  knowledge_graph_inst,
1817
  entities_vdb,
1818
  text_chunks_db,
1819
  query_param,
1820
  )
1821
+ all_chunks.extend(entity_chunks)
1822
+
1823
  elif query_param.mode == "global":
1824
+ entities_context, relations_context, relationship_chunks = await _get_edge_data(
1825
  hl_keywords,
1826
  knowledge_graph_inst,
1827
  relationships_vdb,
1828
  text_chunks_db,
1829
  query_param,
1830
  )
1831
+ all_chunks.extend(relationship_chunks)
1832
+
1833
  else: # hybrid or mix mode
1834
  ll_data = await _get_node_data(
1835
  ll_keywords,
 
1846
  query_param,
1847
  )
1848
 
1849
+ (ll_entities_context, ll_relations_context, ll_chunks) = ll_data
1850
+ (hl_entities_context, hl_relations_context, hl_chunks) = hl_data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1851
 
1852
+ # Collect chunks from entity and relationship sources
1853
+ all_chunks.extend(ll_chunks)
1854
+ all_chunks.extend(hl_chunks)
 
1855
 
1856
+ # Get vector chunks if in mix mode
1857
+ if query_param.mode == "mix" and chunks_vdb:
1858
+ vector_chunks = await _get_vector_context(
1859
+ query,
1860
  chunks_vdb,
1861
  query_param,
 
1862
  )
1863
+ all_chunks.extend(vector_chunks)
1864
 
1865
+ # Combine entities and relations contexts
 
 
 
 
 
 
 
 
1866
  entities_context = process_combine_contexts(
1867
+ hl_entities_context, ll_entities_context
1868
  )
1869
  relations_context = process_combine_contexts(
1870
+ hl_relations_context, ll_relations_context
1871
  )
1872
+
1873
+ # Process all chunks uniformly: deduplication, reranking, and token truncation
1874
+ processed_chunks = await process_chunks_unified(
1875
+ query=query,
1876
+ chunks=all_chunks,
1877
+ query_param=query_param,
1878
+ global_config=text_chunks_db.global_config,
1879
+ source_type="mixed",
1880
+ )
1881
+
1882
+ # Build final text_units_context from processed chunks
1883
+ text_units_context = []
1884
+ for i, chunk in enumerate(processed_chunks):
1885
+ text_units_context.append(
1886
+ {
1887
+ "id": i + 1,
1888
+ "content": chunk["content"],
1889
+ "file_path": chunk.get("file_path", "unknown_source"),
1890
+ }
1891
  )
1892
+
1893
+ logger.info(
1894
+ f"Final context: {len(entities_context)} entities, {len(relations_context)} relations, {len(text_units_context)} chunks"
1895
+ )
1896
+
1897
  # not necessary to use LLM to generate a response
1898
  if not entities_context and not relations_context:
1899
  return None
1900
 
 
1901
  entities_str = json.dumps(entities_context, ensure_ascii=False)
1902
  relations_str = json.dumps(relations_context, ensure_ascii=False)
1903
  text_units_str = json.dumps(text_units_context, ensure_ascii=False)
 
2044
  }
2045
  )
2046
 
2047
+ return entities_context, relations_context, use_text_units
 
 
 
 
 
 
 
 
 
2048
 
2049
 
2050
  async def _find_most_related_text_unit_from_entities(
 
2133
  logger.warning("No valid text units found")
2134
  return []
2135
 
2136
+ # Sort by relation counts and order, but don't truncate
2137
  all_text_units = sorted(
2138
  all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
2139
  )
 
 
 
 
 
 
2140
 
2141
+ logger.debug(f"Found {len(all_text_units)} entity-related chunks")
 
 
2142
 
2143
+ # Add source type marking and return chunk data
2144
+ result_chunks = []
2145
+ for t in all_text_units:
2146
+ chunk_data = t["data"].copy()
2147
+ chunk_data["source_type"] = "entity"
2148
+ result_chunks.append(chunk_data)
2149
+
2150
+ return result_chunks
2151
 
2152
 
2153
  async def _find_most_related_edges_from_entities(
 
2449
  logger.warning("No valid text chunks after filtering")
2450
  return []
2451
 
2452
+ logger.debug(f"Found {len(valid_text_units)} relationship-related chunks")
 
 
 
 
 
 
 
 
 
 
2453
 
2454
+ # Add source type marking and return chunk data
2455
+ result_chunks = []
2456
+ for t in valid_text_units:
2457
+ chunk_data = t["data"].copy()
2458
+ chunk_data["source_type"] = "relationship"
2459
+ result_chunks.append(chunk_data)
2460
 
2461
+ return result_chunks
2462
 
2463
 
2464
  async def naive_query(
 
  tokenizer: Tokenizer = global_config["tokenizer"]

+ chunks = await _get_vector_context(query, chunks_vdb, query_param)

+ if chunks is None or len(chunks) == 0:
  return PROMPTS["fail_response"]

+ # Process chunks using unified processing
+ processed_chunks = await process_chunks_unified(
+ query=query,
+ chunks=chunks,
+ query_param=query_param,
+ global_config=global_config,
+ source_type="vector",
+ )
+
+ logger.info(f"Final context: {len(processed_chunks)} chunks")
+
+ # Build text_units_context from processed chunks
+ text_units_context = []
+ for i, chunk in enumerate(processed_chunks):
+ text_units_context.append(
+ {
+ "id": i + 1,
+ "content": chunk["content"],
+ "file_path": chunk.get("file_path", "unknown_source"),
+ }
+ )
+
  text_units_str = json.dumps(text_units_context, ensure_ascii=False)
  if query_param.only_need_context:
  return f"""
 
  hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""

  context = await _build_query_context(
+ query,
  ll_keywords_str,
  hl_keywords_str,
  knowledge_graph_inst,

  f"{prompt}\n\n### Keywords\n\n{keywords_str}\n\n### Query\n\n{query}"
  )

  # Use appropriate query method based on mode
  if param.mode in ["local", "global", "hybrid", "mix"]:
  return await kg_query_with_keywords(

  )
  else:
  raise ValueError(f"Unknown mode {param.mode}")
+
+
+ async def apply_rerank_if_enabled(
+ query: str,
+ retrieved_docs: list[dict],
+ global_config: dict,
+ top_k: int = None,
+ ) -> list[dict]:
+ """
+ Apply reranking to retrieved documents if rerank is enabled.
+
+ Args:
+ query: The search query
+ retrieved_docs: List of retrieved documents
+ global_config: Global configuration containing rerank settings
+ top_k: Number of top documents to return after reranking
+
+ Returns:
+ Reranked documents if rerank is enabled, otherwise original documents
+ """
+ if not global_config.get("enable_rerank", False) or not retrieved_docs:
+ return retrieved_docs
+
+ rerank_func = global_config.get("rerank_model_func")
+ if not rerank_func:
+ logger.debug(
+ "Rerank is enabled but no rerank function provided, skipping rerank"
+ )
+ return retrieved_docs
+
+ try:
+ logger.debug(
+ f"Applying rerank to {len(retrieved_docs)} documents, returning top {top_k}"
+ )
+
+ # Apply reranking - let rerank_model_func handle top_k internally
+ reranked_docs = await rerank_func(
+ query=query,
+ documents=retrieved_docs,
+ top_k=top_k,
+ )
+ if reranked_docs and len(reranked_docs) > 0:
+ if len(reranked_docs) > top_k:
+ reranked_docs = reranked_docs[:top_k]
+ logger.info(
+ f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}"
+ )
+ return reranked_docs
+ else:
+ logger.warning("Rerank returned empty results, using original documents")
+ return retrieved_docs
+
+ except Exception as e:
+ logger.error(f"Error during reranking: {e}, using original documents")
+ return retrieved_docs
+
+
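The helper above only requires that `rerank_model_func` be an async callable accepting `query`, `documents`, and `top_k` keyword arguments and returning a list of dicts. A minimal conforming callable could look like the sketch below; the keyword-overlap scoring is a stand-in for a real rerank model and is not part of this PR:

```python
# Hypothetical stand-in for rerank_model_func; the scoring is illustrative only.
async def keyword_overlap_rerank(
    query: str, documents: list[dict], top_k: int | None = None, **kwargs
) -> list[dict]:
    terms = query.lower().split()
    # Order documents by how many query terms appear in their content.
    scored = sorted(
        documents,
        key=lambda d: sum(t in d.get("content", "").lower() for t in terms),
        reverse=True,
    )
    return scored[:top_k] if top_k else scored
```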
+ async def process_chunks_unified(
+ query: str,
+ chunks: list[dict],
+ query_param: QueryParam,
+ global_config: dict,
+ source_type: str = "mixed",
+ ) -> list[dict]:
+ """
+ Unified processing for text chunks: deduplication, chunk_top_k limiting, reranking, and token truncation.
+
+ Args:
+ query: Search query for reranking
+ chunks: List of text chunks to process
+ query_param: Query parameters containing configuration
+ global_config: Global configuration dictionary
+ source_type: Source type for logging ("vector", "entity", "relationship", "mixed")
+
+ Returns:
+ Processed and filtered list of text chunks
+ """
+ if not chunks:
+ return []
+
+ # 1. Deduplication based on content
+ seen_content = set()
+ unique_chunks = []
+ for chunk in chunks:
+ content = chunk.get("content", "")
+ if content and content not in seen_content:
+ seen_content.add(content)
+ unique_chunks.append(chunk)
+
+ logger.debug(
+ f"Deduplication: {len(unique_chunks)} chunks (original: {len(chunks)})"
+ )
+
+ # 2. Apply reranking if enabled and query is provided
+ if global_config.get("enable_rerank", False) and query and unique_chunks:
+ rerank_top_k = query_param.chunk_rerank_top_k or len(unique_chunks)
+ unique_chunks = await apply_rerank_if_enabled(
+ query=query,
+ retrieved_docs=unique_chunks,
+ global_config=global_config,
+ top_k=rerank_top_k,
+ )
+ logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})")
+
+ # 3. Apply chunk_top_k limiting if specified
+ if query_param.chunk_top_k is not None and query_param.chunk_top_k > 0:
+ if len(unique_chunks) > query_param.chunk_top_k:
+ unique_chunks = unique_chunks[: query_param.chunk_top_k]
+ logger.debug(
+ f"Chunk top-k limiting: kept {len(unique_chunks)} chunks (chunk_top_k={query_param.chunk_top_k})"
+ )
+
+ # 4. Token-based final truncation
+ tokenizer = global_config.get("tokenizer")
+ if tokenizer and unique_chunks:
+ original_count = len(unique_chunks)
+ unique_chunks = truncate_list_by_token_size(
+ unique_chunks,
+ key=lambda x: x.get("content", ""),
+ max_token_size=query_param.max_token_for_text_unit,
+ tokenizer=tokenizer,
+ )
+ logger.debug(
+ f"Token truncation: {len(unique_chunks)} chunks from {original_count} "
+ f"(max tokens: {query_param.max_token_for_text_unit}, source: {source_type})"
+ )
+
+ return unique_chunks
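For reference, the four stages of `process_chunks_unified` (deduplicate by content, rerank when enabled, cap at `chunk_top_k`, then truncate to the token budget) compose as in the toy sketch below. It is a simplified re-implementation on plain dicts, not the LightRAG call path, and character counts stand in for tokens:

```python
# Toy sketch of the same pipeline; illustrative only, not library code.
def toy_process(chunks: list[dict], chunk_top_k: int = 2, max_chars: int = 12) -> list[dict]:
    seen, unique = set(), []
    for c in chunks:                      # 1) deduplicate by content
        if c["content"] not in seen:
            seen.add(c["content"])
            unique.append(c)
    # 2) a rerank step would reorder `unique` here when enabled
    unique = unique[:chunk_top_k]         # 3) cap at chunk_top_k
    kept, used = [], 0
    for c in unique:                      # 4) enforce the size budget
        if used + len(c["content"]) > max_chars:
            break
        kept.append(c)
        used += len(c["content"])
    return kept

print(toy_process([{"content": "alpha"}, {"content": "alpha"},
                   {"content": "beta"}, {"content": "gamma"}]))
# -> [{'content': 'alpha'}, {'content': 'beta'}]
```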
lightrag/rerank.py ADDED
@@ -0,0 +1,321 @@
+ from __future__ import annotations
+
+ import os
+ import aiohttp
+ from typing import Callable, Any, List, Dict, Optional
+ from pydantic import BaseModel, Field
+
+ from .utils import logger
+
+
+ class RerankModel(BaseModel):
+ """
+ Pydantic model class for defining a custom rerank model.
+
+ This class provides a convenient wrapper for rerank functions, allowing you to
+ encapsulate all rerank configurations (API keys, model settings, etc.) in one place.
+
+ Attributes:
+ rerank_func (Callable[[Any], List[Dict]]): A callable function that reranks documents.
+ The function should take query and documents as input and return reranked results.
+ kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function.
+ This should include all necessary configurations such as model name, API key, base_url, etc.
+
+ Example usage:
+ Rerank model example with Jina:
+ ```python
+ rerank_model = RerankModel(
+ rerank_func=jina_rerank,
+ kwargs={
+ "model": "BAAI/bge-reranker-v2-m3",
+ "api_key": "your_api_key_here",
+ "base_url": "https://api.jina.ai/v1/rerank"
+ }
+ )
+
+ # Use in LightRAG
+ rag = LightRAG(
+ enable_rerank=True,
+ rerank_model_func=rerank_model.rerank,
+ # ... other configurations
+ )
+ ```
+
+ Or define a custom function directly:
+ ```python
+ async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
+ return await jina_rerank(
+ query=query,
+ documents=documents,
+ model="BAAI/bge-reranker-v2-m3",
+ api_key="your_api_key_here",
+ top_k=top_k or 10,
+ **kwargs
+ )
+
+ rag = LightRAG(
+ enable_rerank=True,
+ rerank_model_func=my_rerank_func,
+ # ... other configurations
+ )
+ ```
+ """
+
+ rerank_func: Callable[[Any], List[Dict]]
+ kwargs: Dict[str, Any] = Field(default_factory=dict)
+
+ async def rerank(
+ self,
+ query: str,
+ documents: List[Dict[str, Any]],
+ top_k: Optional[int] = None,
+ **extra_kwargs,
+ ) -> List[Dict[str, Any]]:
+ """Rerank documents using the configured model function."""
+ # Merge extra kwargs with model kwargs
+ kwargs = {**self.kwargs, **extra_kwargs}
+ return await self.rerank_func(
+ query=query, documents=documents, top_k=top_k, **kwargs
+ )
+
+
+ class MultiRerankModel(BaseModel):
+ """Multiple rerank models for different modes/scenarios."""
+
+ # Primary rerank model (used if mode-specific models are not defined)
+ rerank_model: Optional[RerankModel] = None
+
+ # Mode-specific rerank models
+ entity_rerank_model: Optional[RerankModel] = None
+ relation_rerank_model: Optional[RerankModel] = None
+ chunk_rerank_model: Optional[RerankModel] = None
+
+ async def rerank(
+ self,
+ query: str,
+ documents: List[Dict[str, Any]],
+ mode: str = "default",
+ top_k: Optional[int] = None,
+ **kwargs,
+ ) -> List[Dict[str, Any]]:
+ """Rerank using the appropriate model based on mode."""
+
+ # Select model based on mode
+ if mode == "entity" and self.entity_rerank_model:
+ model = self.entity_rerank_model
+ elif mode == "relation" and self.relation_rerank_model:
+ model = self.relation_rerank_model
+ elif mode == "chunk" and self.chunk_rerank_model:
+ model = self.chunk_rerank_model
+ elif self.rerank_model:
+ model = self.rerank_model
+ else:
+ logger.warning(f"No rerank model available for mode: {mode}")
+ return documents
+
+ return await model.rerank(query, documents, top_k, **kwargs)
+
+
+ async def generic_rerank_api(
+ query: str,
+ documents: List[Dict[str, Any]],
+ model: str,
+ base_url: str,
+ api_key: str,
+ top_k: Optional[int] = None,
+ **kwargs,
+ ) -> List[Dict[str, Any]]:
+ """
+ Generic rerank function that works with Jina/Cohere compatible APIs.
+
+ Args:
+ query: The search query
+ documents: List of documents to rerank
+ model: Model identifier
+ base_url: API endpoint URL
+ api_key: API authentication key
+ top_k: Number of top results to return
+ **kwargs: Additional API-specific parameters
+
+ Returns:
+ List of reranked documents with relevance scores
+ """
+ if not api_key:
+ logger.warning("No API key provided for rerank service")
+ return documents
+
+ if not documents:
+ return documents
+
+ # Prepare documents for reranking - handle both text and dict formats
+ prepared_docs = []
+ for doc in documents:
+ if isinstance(doc, dict):
+ # Use 'content' field if available, otherwise use 'text' or convert to string
+ text = doc.get("content") or doc.get("text") or str(doc)
+ else:
+ text = str(doc)
+ prepared_docs.append(text)
+
+ # Prepare request
+ headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
+
+ data = {"model": model, "query": query, "documents": prepared_docs, **kwargs}
+
+ if top_k is not None:
+ data["top_k"] = min(top_k, len(prepared_docs))
+
+ try:
+ async with aiohttp.ClientSession() as session:
+ async with session.post(base_url, headers=headers, json=data) as response:
+ if response.status != 200:
+ error_text = await response.text()
+ logger.error(f"Rerank API error {response.status}: {error_text}")
+ return documents
+
+ result = await response.json()
+
+ # Extract reranked results
+ if "results" in result:
+ # Standard format: results contain index and relevance_score
+ reranked_docs = []
+ for item in result["results"]:
+ if "index" in item:
+ doc_idx = item["index"]
+ if 0 <= doc_idx < len(documents):
+ reranked_doc = documents[doc_idx].copy()
+ if "relevance_score" in item:
+ reranked_doc["rerank_score"] = item[
+ "relevance_score"
+ ]
+ reranked_docs.append(reranked_doc)
+ return reranked_docs
+ else:
+ logger.warning("Unexpected rerank API response format")
+ return documents
+
+ except Exception as e:
+ logger.error(f"Error during reranking: {e}")
+ return documents
+
+
+ async def jina_rerank(
+ query: str,
+ documents: List[Dict[str, Any]],
+ model: str = "BAAI/bge-reranker-v2-m3",
+ top_k: Optional[int] = None,
+ base_url: str = "https://api.jina.ai/v1/rerank",
+ api_key: Optional[str] = None,
+ **kwargs,
+ ) -> List[Dict[str, Any]]:
+ """
+ Rerank documents using Jina AI API.
+
+ Args:
+ query: The search query
+ documents: List of documents to rerank
+ model: Jina rerank model name
+ top_k: Number of top results to return
+ base_url: Jina API endpoint
+ api_key: Jina API key
+ **kwargs: Additional parameters
+
+ Returns:
+ List of reranked documents with relevance scores
+ """
+ if api_key is None:
+ api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_API_KEY")
+
+ return await generic_rerank_api(
+ query=query,
+ documents=documents,
+ model=model,
+ base_url=base_url,
+ api_key=api_key,
+ top_k=top_k,
+ **kwargs,
+ )
+
+
+ async def cohere_rerank(
+ query: str,
+ documents: List[Dict[str, Any]],
+ model: str = "rerank-english-v2.0",
+ top_k: Optional[int] = None,
+ base_url: str = "https://api.cohere.ai/v1/rerank",
+ api_key: Optional[str] = None,
+ **kwargs,
+ ) -> List[Dict[str, Any]]:
+ """
+ Rerank documents using Cohere API.
+
+ Args:
+ query: The search query
+ documents: List of documents to rerank
+ model: Cohere rerank model name
+ top_k: Number of top results to return
+ base_url: Cohere API endpoint
+ api_key: Cohere API key
+ **kwargs: Additional parameters
+
+ Returns:
+ List of reranked documents with relevance scores
+ """
+ if api_key is None:
+ api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_API_KEY")
+
+ return await generic_rerank_api(
+ query=query,
+ documents=documents,
+ model=model,
+ base_url=base_url,
+ api_key=api_key,
+ top_k=top_k,
+ **kwargs,
+ )
+
+
+ # Convenience function for custom API endpoints
+ async def custom_rerank(
+ query: str,
+ documents: List[Dict[str, Any]],
+ model: str,
+ base_url: str,
+ api_key: str,
+ top_k: Optional[int] = None,
+ **kwargs,
+ ) -> List[Dict[str, Any]]:
+ """
+ Rerank documents using a custom API endpoint.
+ This is useful for self-hosted or custom rerank services.
+ """
+ return await generic_rerank_api(
+ query=query,
+ documents=documents,
+ model=model,
+ base_url=base_url,
+ api_key=api_key,
+ top_k=top_k,
+ **kwargs,
+ )
+
+
+ if __name__ == "__main__":
+ import asyncio
+
+ async def main():
+ # Example usage
+ docs = [
+ {"content": "The capital of France is Paris."},
+ {"content": "Tokyo is the capital of Japan."},
+ {"content": "London is the capital of England."},
+ ]
+
+ query = "What is the capital of France?"
+
+ result = await jina_rerank(
+ query=query, documents=docs, top_k=2, api_key="your-api-key-here"
+ )
+ print(result)
+
+ asyncio.run(main())
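As a usage note, the convenience functions above can be bound to a concrete endpoint once and then passed to LightRAG as `rerank_model_func`, along the lines of the `RerankModel` docstring. A minimal sketch, assuming a self-hosted service; the endpoint URL, model name, and key below are placeholders:

```python
# Sketch: bind custom_rerank to a (placeholder) self-hosted rerank endpoint.
from functools import partial

from lightrag.rerank import custom_rerank

my_rerank_func = partial(
    custom_rerank,
    model="my-rerank-model",                  # placeholder model id
    base_url="http://localhost:8000/rerank",  # placeholder endpoint
    api_key="your_api_key_here",              # placeholder key
)
# Pass my_rerank_func as rerank_model_func (with enable_rerank=True) when
# constructing LightRAG, alongside the rest of your configuration.
```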
lightrag_webui/src/stores/settings.ts CHANGED
@@ -111,7 +111,7 @@ const useSettingsStoreBase = create<SettingsState>()(
  mode: 'global',
  response_type: 'Multiple Paragraphs',
  top_k: 10,
- max_token_for_text_unit: 4000,
+ max_token_for_text_unit: 6000,
  max_token_for_global_context: 4000,
  max_token_for_local_context: 4000,
  only_need_context: false,
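On the query side, the chunk controls read by `process_chunks_unified` are plain `QueryParam` fields. A hedged example follows; the values are illustrative, and the commented call assumes an already-constructed `rag` instance:

```python
from lightrag import QueryParam

param = QueryParam(
    mode="mix",
    chunk_top_k=5,                 # chunks kept after reranking
    chunk_rerank_top_k=5,          # chunks requested from the rerank model
    max_token_for_text_unit=6000,  # token budget, matching the new WebUI default
)
# result = await rag.aquery("What does this PR change?", param=param)
```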