zrguo commited on
Commit
fc091ae
·
1 Parent(s): 8089f33

add rerank model

docs/rerank_integration.md ADDED
@@ -0,0 +1,271 @@
# Rerank Integration in LightRAG

This document explains how to configure and use the rerank functionality in LightRAG to improve retrieval quality.

## ⚠️ Important: Parameter Priority

**QueryParam.top_k has higher priority than the rerank_top_k configuration:**

- When you set `QueryParam(top_k=5)`, it overrides the `rerank_top_k=10` setting in the LightRAG configuration
- The actual number of documents sent to the reranker is therefore determined by QueryParam.top_k
- For predictable rerank behavior, choose the top_k value in your QueryParam calls deliberately
- Example: `rag.aquery(query, param=QueryParam(mode="naive", top_k=20))` uses 20, not rerank_top_k
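The precedence rule can be sketched in a few lines (the helper name `effective_rerank_top_k` is illustrative, not part of LightRAG; the actual logic lives in `apply_rerank_if_enabled` in `lightrag/operate.py`):

```python
def effective_rerank_top_k(query_top_k, config_rerank_top_k=10, num_docs=100):
    """QueryParam.top_k wins when set; rerank_top_k is only the fallback."""
    top_k = query_top_k or config_rerank_top_k
    # Never ask the reranker for more documents than were retrieved.
    return min(top_k, num_docs)
```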

## Overview

Reranking is an optional feature that improves the quality of retrieved documents by re-ordering them based on their relevance to the query. This is particularly useful when you want higher precision in document retrieval across all query modes (naive, local, global, hybrid, mix).

## Architecture

The rerank integration follows the same design pattern as the LLM integration:

- **Configurable Models**: Support for multiple rerank providers through a generic API
- **Async Processing**: Non-blocking rerank operations
- **Error Handling**: Graceful fallback to original results
- **Optional Feature**: Can be enabled/disabled via configuration
- **Code Reuse**: Single generic implementation for Jina/Cohere-compatible APIs

## Configuration

### Environment Variables

Set these variables in your `.env` file or environment:

```bash
# Enable/disable reranking
ENABLE_RERANK=True

# Rerank model configuration
RERANK_MODEL=BAAI/bge-reranker-v2-m3
RERANK_MAX_ASYNC=4
RERANK_TOP_K=10

# API configuration
RERANK_API_KEY=your_rerank_api_key_here
RERANK_BASE_URL=https://api.your-provider.com/v1/rerank

# Provider-specific keys (optional alternatives)
JINA_API_KEY=your_jina_api_key_here
COHERE_API_KEY=your_cohere_api_key_here
```

### Programmatic Configuration

```python
from lightrag import LightRAG
from lightrag.rerank import custom_rerank, RerankModel

# Method 1: Using environment variables (recommended)
rag = LightRAG(
    working_dir="./rag_storage",
    llm_model_func=your_llm_func,
    embedding_func=your_embedding_func,
    # Rerank automatically configured from environment variables
)

# Method 2: Explicit configuration
rerank_model = RerankModel(
    rerank_func=custom_rerank,
    kwargs={
        "model": "BAAI/bge-reranker-v2-m3",
        "base_url": "https://api.your-provider.com/v1/rerank",
        "api_key": "your_api_key_here",
    },
)

rag = LightRAG(
    working_dir="./rag_storage",
    llm_model_func=your_llm_func,
    embedding_func=your_embedding_func,
    enable_rerank=True,
    rerank_model_func=rerank_model.rerank,
    rerank_top_k=10,
)
```

## Supported Providers

### 1. Custom/Generic API (Recommended)

For Jina/Cohere compatible APIs:

```python
from lightrag.rerank import custom_rerank

# Your custom API endpoint
result = await custom_rerank(
    query="your query",
    documents=documents,
    model="BAAI/bge-reranker-v2-m3",
    base_url="https://api.your-provider.com/v1/rerank",
    api_key="your_api_key_here",
    top_k=10,
)
```

### 2. Jina AI

```python
from lightrag.rerank import jina_rerank

result = await jina_rerank(
    query="your query",
    documents=documents,
    model="BAAI/bge-reranker-v2-m3",
    api_key="your_jina_api_key",
)
```

### 3. Cohere

```python
from lightrag.rerank import cohere_rerank

result = await cohere_rerank(
    query="your query",
    documents=documents,
    model="rerank-english-v2.0",
    api_key="your_cohere_api_key",
)
```

## Integration Points

Reranking is automatically applied at these key retrieval stages:

1. **Naive Mode**: After vector similarity search in `_get_vector_context`
2. **Local Mode**: After entity retrieval in `_get_node_data`
3. **Global Mode**: After relationship retrieval in `_get_edge_data`
4. **Hybrid/Mix Modes**: Applied to all relevant components

## Configuration Parameters

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `enable_rerank` | bool | False | Enable/disable reranking |
| `rerank_model_name` | str | "BAAI/bge-reranker-v2-m3" | Model identifier |
| `rerank_model_max_async` | int | 4 | Max concurrent rerank calls |
| `rerank_top_k` | int | 10 | Number of top results to return ⚠️ **Overridden by QueryParam.top_k** |
| `rerank_model_func` | callable | None | Custom rerank function |
| `rerank_model_kwargs` | dict | {} | Additional rerank parameters |

## Example Usage

### Basic Usage

```python
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed


async def main():
    # Initialize with rerank enabled
    rag = LightRAG(
        working_dir="./rag_storage",
        llm_model_func=gpt_4o_mini_complete,
        embedding_func=openai_embed,
        enable_rerank=True,
    )

    # Insert documents
    await rag.ainsert([
        "Document 1 content...",
        "Document 2 content...",
    ])

    # Query with rerank (automatically applied)
    result = await rag.aquery(
        "Your question here",
        param=QueryParam(mode="hybrid", top_k=5),  # ⚠️ This top_k=5 overrides rerank_top_k
    )

    print(result)


asyncio.run(main())
```

### Direct Rerank Usage

```python
from lightrag.rerank import custom_rerank


async def test_rerank():
    documents = [
        {"content": "Text about topic A"},
        {"content": "Text about topic B"},
        {"content": "Text about topic C"},
    ]

    reranked = await custom_rerank(
        query="Tell me about topic A",
        documents=documents,
        model="BAAI/bge-reranker-v2-m3",
        base_url="https://api.your-provider.com/v1/rerank",
        api_key="your_api_key_here",
        top_k=2,
    )

    for doc in reranked:
        print(f"Score: {doc.get('rerank_score')}, Content: {doc.get('content')}")
```
211
+
212
+ ## Best Practices
213
+
214
+ 1. **Parameter Priority Awareness**: Remember that QueryParam.top_k always overrides rerank_top_k configuration
215
+ 2. **Performance**: Use reranking selectively for better performance vs. quality tradeoff
216
+ 3. **API Limits**: Monitor API usage and implement rate limiting if needed
217
+ 4. **Fallback**: Always handle rerank failures gracefully (returns original results)
218
+ 5. **Top-k Selection**: Choose appropriate `top_k` values in QueryParam based on your use case
219
+ 6. **Cost Management**: Consider rerank API costs in your budget planning
220
+

## Troubleshooting

### Common Issues

1. **API Key Missing**: Ensure `RERANK_API_KEY` or a provider-specific key is set
2. **Network Issues**: Check `RERANK_BASE_URL` and network connectivity
3. **Model Errors**: Verify that the rerank model name is supported by your API
4. **Document Format**: Ensure documents have `content` or `text` fields

### Debug Mode

Enable debug logging to see rerank operations:

```python
import logging

logging.getLogger("lightrag.rerank").setLevel(logging.DEBUG)
```

### Error Handling

The rerank integration includes automatic fallback:

- If rerank fails, the original documents are returned
- No exceptions are raised to the user
- Errors are logged for debugging
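This behavior can be sketched as a small wrapper (illustrative only; the real logic lives in `apply_rerank_if_enabled` in `lightrag/operate.py`):

```python
import asyncio


async def rerank_with_fallback(rerank_func, query, documents, top_k=None):
    """Return reranked documents, falling back to the originals on any failure."""
    try:
        reranked = await rerank_func(query=query, documents=documents, top_k=top_k)
        # An empty rerank result also falls back to the original documents.
        return reranked if reranked else documents
    except Exception:
        # Errors would be logged here; they are never propagated to the caller.
        return documents
```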

## API Compatibility

The generic rerank API expects this response format:

```json
{
  "results": [
    {
      "index": 0,
      "relevance_score": 0.95
    },
    {
      "index": 2,
      "relevance_score": 0.87
    }
  ]
}
```

This is compatible with:

- Jina AI Rerank API
- Cohere Rerank API
- Custom APIs following the same format
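Mapping such a response back onto the original documents works as in this sketch (the function name `apply_rerank_response` is hypothetical; `generic_rerank_api` in `lightrag/rerank.py` does the equivalent internally):

```python
def apply_rerank_response(documents, response):
    """Reorder `documents` by the indices in a Jina/Cohere-style response."""
    reranked = []
    for item in response.get("results", []):
        idx = item.get("index")
        if idx is not None and 0 <= idx < len(documents):
            doc = dict(documents[idx])  # copy, don't mutate the original
            if "relevance_score" in item:
                doc["rerank_score"] = item["relevance_score"]
            reranked.append(doc)
    # Fall back to the original order if the response carried no results.
    return reranked or documents
```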
env.example CHANGED
@@ -179,3 +179,14 @@ QDRANT_URL=http://localhost:6333
 ### Redis
 REDIS_URI=redis://localhost:6379
 # REDIS_WORKSPACE=forced_workspace_name
+
+# Rerank Configuration
+ENABLE_RERANK=False
+RERANK_MODEL=BAAI/bge-reranker-v2-m3
+RERANK_MAX_ASYNC=4
+RERANK_TOP_K=10
+# Note: QueryParam.top_k in your code will override RERANK_TOP_K setting
+
+# Rerank API Configuration
+RERANK_API_KEY=your_rerank_api_key_here
+RERANK_BASE_URL=https://api.your-provider.com/v1/rerank
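These variables are read into dataclass field defaults in `lightrag/lightrag.py` (see the changes below); the parsing rules amount to:

```python
import os

# Sketch of how the rerank env vars become config defaults: booleans are
# compared case-insensitively against "true", counts are parsed as ints.
enable_rerank = os.getenv("ENABLE_RERANK", "False").lower() == "true"
rerank_model = os.getenv("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
rerank_max_async = int(os.getenv("RERANK_MAX_ASYNC", "4"))
rerank_top_k = int(os.getenv("RERANK_TOP_K", "10"))
```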
examples/rerank_example.py ADDED
@@ -0,0 +1,193 @@
"""
LightRAG Rerank Integration Example

This example demonstrates how to use rerank functionality with LightRAG
to improve retrieval quality across different query modes.

IMPORTANT: Parameter Priority
- QueryParam(top_k=N) has higher priority than rerank_top_k in LightRAG configuration
- If you set QueryParam(top_k=5), it will override the rerank_top_k setting
- For optimal rerank performance, use appropriate top_k values in QueryParam

Configuration Required:
1. Set your LLM API key and base URL in llm_model_func()
2. Set your embedding API key and base URL in embedding_func()
3. Set your rerank API key and base URL in the rerank configuration
4. Or use environment variables (.env file):
   - RERANK_API_KEY=your_actual_rerank_api_key
   - RERANK_BASE_URL=https://your-actual-rerank-endpoint/v1/rerank
   - RERANK_MODEL=your_rerank_model_name
"""

import asyncio
import os

import numpy as np

from lightrag import LightRAG, QueryParam
from lightrag.rerank import custom_rerank, RerankModel
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
from lightrag.utils import EmbeddingFunc, setup_logger

# Set up your working directory
WORKING_DIR = "./test_rerank"
setup_logger("test_rerank")

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)


async def llm_model_func(
    prompt, system_prompt=None, history_messages=[], **kwargs
) -> str:
    return await openai_complete_if_cache(
        "gpt-4o-mini",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        api_key="your_llm_api_key_here",
        base_url="https://api.your-llm-provider.com/v1",
        **kwargs,
    )


async def embedding_func(texts: list[str]) -> np.ndarray:
    return await openai_embed(
        texts,
        model="text-embedding-3-large",
        api_key="your_embedding_api_key_here",
        base_url="https://api.your-embedding-provider.com/v1",
    )


async def create_rag_with_rerank():
    """Create LightRAG instance with rerank configuration"""

    # Get embedding dimension
    test_embedding = await embedding_func(["test"])
    embedding_dim = test_embedding.shape[1]
    print(f"Detected embedding dimension: {embedding_dim}")

    # Create rerank model
    rerank_model = RerankModel(
        rerank_func=custom_rerank,
        kwargs={
            "model": "BAAI/bge-reranker-v2-m3",
            "base_url": "https://api.your-rerank-provider.com/v1/rerank",
            "api_key": "your_rerank_api_key_here",
        },
    )

    # Initialize LightRAG with rerank
    rag = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=llm_model_func,
        embedding_func=EmbeddingFunc(
            embedding_dim=embedding_dim,
            max_token_size=8192,
            func=embedding_func,
        ),
        # Rerank Configuration
        enable_rerank=True,
        rerank_model_func=rerank_model.rerank,
        rerank_top_k=10,  # Note: QueryParam.top_k will override this
    )

    return rag


async def test_rerank_with_different_topk():
    """
    Test rerank functionality with different top_k settings to demonstrate parameter priority
    """
    print("🚀 Setting up LightRAG with Rerank functionality...")

    rag = await create_rag_with_rerank()

    # Insert sample documents
    sample_docs = [
        "Reranking improves retrieval quality by re-ordering documents based on relevance.",
        "LightRAG is a powerful retrieval-augmented generation system with multiple query modes.",
        "Vector databases enable efficient similarity search in high-dimensional embedding spaces.",
        "Natural language processing has evolved with large language models and transformers.",
        "Machine learning algorithms can learn patterns from data without explicit programming.",
    ]

    print("📄 Inserting sample documents...")
    await rag.ainsert(sample_docs)

    query = "How does reranking improve retrieval quality?"
    print(f"\n🔍 Testing query: '{query}'")
    print("=" * 80)

    # Test different top_k values to show parameter priority
    top_k_values = [2, 5, 10]

    for top_k in top_k_values:
        print(f"\n📊 Testing with QueryParam(top_k={top_k}) - overrides rerank_top_k=10:")

        # Test naive mode with specific top_k
        result = await rag.aquery(
            query,
            param=QueryParam(mode="naive", top_k=top_k),
        )
        print(f"   Result length: {len(result)} characters")
        print(f"   Preview: {result[:100]}...")


async def test_direct_rerank():
    """Test rerank function directly"""
    print("\n🔧 Direct Rerank API Test")
    print("=" * 40)

    documents = [
        {"content": "Reranking significantly improves retrieval quality"},
        {"content": "LightRAG supports advanced reranking capabilities"},
        {"content": "Vector search finds semantically similar documents"},
        {"content": "Natural language processing with modern transformers"},
        {"content": "The quick brown fox jumps over the lazy dog"},
    ]

    query = "rerank improve quality"
    print(f"Query: '{query}'")
    print(f"Documents: {len(documents)}")

    try:
        reranked_docs = await custom_rerank(
            query=query,
            documents=documents,
            model="BAAI/bge-reranker-v2-m3",
            base_url="https://api.your-rerank-provider.com/v1/rerank",
            api_key="your_rerank_api_key_here",
            top_k=3,
        )

        print("\n✅ Rerank Results:")
        for i, doc in enumerate(reranked_docs):
            score = doc.get("rerank_score")
            content = doc.get("content", "")[:60]
            # Guard the float formatting: score is absent if rerank fell back
            score_str = f"{score:.4f}" if isinstance(score, (int, float)) else "N/A"
            print(f"  {i + 1}. Score: {score_str} | {content}...")

    except Exception as e:
        print(f"❌ Rerank failed: {e}")


async def main():
    """Main example function"""
    print("🎯 LightRAG Rerank Integration Example")
    print("=" * 60)

    try:
        # Test rerank with different top_k values
        await test_rerank_with_different_topk()

        # Test direct rerank
        await test_direct_rerank()

        print("\n✅ Example completed successfully!")
        print("\n💡 Key Points:")
        print("   ✓ QueryParam.top_k has higher priority than rerank_top_k")
        print("   ✓ Rerank improves document relevance ordering")
        print("   ✓ Configure API keys in your .env file for production")
        print("   ✓ Monitor API usage and costs when using rerank services")

    except Exception as e:
        print(f"\n❌ Example failed: {e}")
        import traceback

        traceback.print_exc()


if __name__ == "__main__":
    asyncio.run(main())
lightrag/lightrag.py CHANGED
@@ -240,6 +240,35 @@ class LightRAG:
     llm_model_kwargs: dict[str, Any] = field(default_factory=dict)
     """Additional keyword arguments passed to the LLM model function."""
 
+    # Rerank Configuration
+    # ---
+
+    enable_rerank: bool = field(
+        default=bool(os.getenv("ENABLE_RERANK", "False").lower() == "true")
+    )
+    """Enable reranking for improved retrieval quality. Defaults to False."""
+
+    rerank_model_func: Callable[..., object] | None = field(default=None)
+    """Function for reranking retrieved documents. Optional."""
+
+    rerank_model_name: str = field(
+        default=os.getenv("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
+    )
+    """Name of the rerank model used for reranking documents."""
+
+    rerank_model_max_async: int = field(default=int(os.getenv("RERANK_MAX_ASYNC", 4)))
+    """Maximum number of concurrent rerank calls."""
+
+    rerank_model_kwargs: dict[str, Any] = field(default_factory=dict)
+    """Additional keyword arguments passed to the rerank model function."""
+
+    rerank_top_k: int = field(default=int(os.getenv("RERANK_TOP_K", 10)))
+    """Number of top documents to return after reranking.
+
+    Note: This value will be overridden by QueryParam.top_k in query calls.
+    Example: QueryParam(top_k=5) will override rerank_top_k=10 setting.
+    """
+
     # Storage
     # ---
 
@@ -444,6 +473,22 @@
             )
         )
 
+        # Init Rerank
+        if self.enable_rerank and self.rerank_model_func:
+            self.rerank_model_func = priority_limit_async_func_call(
+                self.rerank_model_max_async
+            )(
+                partial(
+                    self.rerank_model_func,  # type: ignore
+                    **self.rerank_model_kwargs,
+                )
+            )
+            logger.info("Rerank model initialized for improved retrieval quality")
+        elif self.enable_rerank and not self.rerank_model_func:
+            logger.warning(
+                "Rerank is enabled but no rerank_model_func provided. Reranking will be skipped."
+            )
+
         self._storages_status = StoragesStatus.CREATED
 
         if self.auto_manage_storages_states:
lightrag/operate.py CHANGED
@@ -1783,6 +1783,15 @@ async def _get_vector_context(
     if not valid_chunks:
         return [], [], []
 
+    # Apply reranking if enabled
+    global_config = chunks_vdb.global_config
+    valid_chunks = await apply_rerank_if_enabled(
+        query=query,
+        retrieved_docs=valid_chunks,
+        global_config=global_config,
+        top_k=query_param.top_k,
+    )
+
     maybe_trun_chunks = truncate_list_by_token_size(
         valid_chunks,
         key=lambda x: x["content"],
@@ -1966,6 +1975,15 @@ async def _get_node_data(
     if not len(results):
         return "", "", ""
 
+    # Apply reranking if enabled for entity results
+    global_config = entities_vdb.global_config
+    results = await apply_rerank_if_enabled(
+        query=query,
+        retrieved_docs=results,
+        global_config=global_config,
+        top_k=query_param.top_k,
+    )
+
     # Extract all entity IDs from your results list
     node_ids = [r["entity_name"] for r in results]
 
@@ -2269,6 +2287,15 @@ async def _get_edge_data(
     if not len(results):
         return "", "", ""
 
+    # Apply reranking if enabled for relationship results
+    global_config = relationships_vdb.global_config
+    results = await apply_rerank_if_enabled(
+        query=keywords,
+        retrieved_docs=results,
+        global_config=global_config,
+        top_k=query_param.top_k,
+    )
+
     # Prepare edge pairs in two forms:
     # For the batch edge properties function, use dicts.
     edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results]
@@ -2806,3 +2833,61 @@ async def query_with_keywords(
         )
     else:
         raise ValueError(f"Unknown mode {param.mode}")
+
+
+async def apply_rerank_if_enabled(
+    query: str,
+    retrieved_docs: list[dict],
+    global_config: dict,
+    top_k: int | None = None,
+) -> list[dict]:
+    """
+    Apply reranking to retrieved documents if rerank is enabled.
+
+    Args:
+        query: The search query
+        retrieved_docs: List of retrieved documents
+        global_config: Global configuration containing rerank settings
+        top_k: Number of top documents to return after reranking
+
+    Returns:
+        Reranked documents if rerank is enabled, otherwise original documents
+    """
+    if not global_config.get("enable_rerank", False) or not retrieved_docs:
+        return retrieved_docs
+
+    rerank_func = global_config.get("rerank_model_func")
+    if not rerank_func:
+        logger.debug(
+            "Rerank is enabled but no rerank function provided, skipping rerank"
+        )
+        return retrieved_docs
+
+    try:
+        # Determine top_k for reranking
+        rerank_top_k = top_k or global_config.get("rerank_top_k", 10)
+        rerank_top_k = min(rerank_top_k, len(retrieved_docs))
+
+        logger.debug(
+            f"Applying rerank to {len(retrieved_docs)} documents, returning top {rerank_top_k}"
+        )
+
+        # Apply reranking
+        reranked_docs = await rerank_func(
+            query=query,
+            documents=retrieved_docs,
+            top_k=rerank_top_k,
+        )
+
+        if reranked_docs and len(reranked_docs) > 0:
+            logger.info(
+                f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}"
+            )
+            return reranked_docs
+        else:
+            logger.warning("Rerank returned empty results, using original documents")
+            return retrieved_docs[:rerank_top_k] if rerank_top_k else retrieved_docs
+
+    except Exception as e:
+        logger.error(f"Error during reranking: {e}, using original documents")
+        return retrieved_docs
lightrag/rerank.py ADDED
@@ -0,0 +1,307 @@
from __future__ import annotations

import os
import aiohttp
from typing import Callable, Any, List, Dict, Optional
from pydantic import BaseModel, Field

from .utils import logger


class RerankModel(BaseModel):
    """
    Pydantic model class for defining a custom rerank model.

    Attributes:
        rerank_func (Callable[[Any], List[Dict]]): A callable function that reranks documents.
            The function should take query and documents as input and return reranked results.
        kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function.
            This could include parameters such as the model name, API key, etc.

    Example usage:
        Rerank model example from jina:
        ```python
        rerank_model = RerankModel(
            rerank_func=jina_rerank,
            kwargs={
                "model": "BAAI/bge-reranker-v2-m3",
                "api_key": "your_api_key_here",
                "base_url": "https://api.jina.ai/v1/rerank",
            }
        )
        ```
    """

    rerank_func: Callable[[Any], List[Dict]]
    kwargs: Dict[str, Any] = Field(default_factory=dict)

    async def rerank(
        self,
        query: str,
        documents: List[Dict[str, Any]],
        top_k: Optional[int] = None,
        **extra_kwargs,
    ) -> List[Dict[str, Any]]:
        """Rerank documents using the configured model function."""
        # Merge extra kwargs with model kwargs
        kwargs = {**self.kwargs, **extra_kwargs}
        return await self.rerank_func(
            query=query,
            documents=documents,
            top_k=top_k,
            **kwargs,
        )


class MultiRerankModel(BaseModel):
    """Multiple rerank models for different modes/scenarios."""

    # Primary rerank model (used if mode-specific models are not defined)
    rerank_model: Optional[RerankModel] = None

    # Mode-specific rerank models
    entity_rerank_model: Optional[RerankModel] = None
    relation_rerank_model: Optional[RerankModel] = None
    chunk_rerank_model: Optional[RerankModel] = None

    async def rerank(
        self,
        query: str,
        documents: List[Dict[str, Any]],
        mode: str = "default",
        top_k: Optional[int] = None,
        **kwargs,
    ) -> List[Dict[str, Any]]:
        """Rerank using the appropriate model based on mode."""
        # Select model based on mode
        if mode == "entity" and self.entity_rerank_model:
            model = self.entity_rerank_model
        elif mode == "relation" and self.relation_rerank_model:
            model = self.relation_rerank_model
        elif mode == "chunk" and self.chunk_rerank_model:
            model = self.chunk_rerank_model
        elif self.rerank_model:
            model = self.rerank_model
        else:
            logger.warning(f"No rerank model available for mode: {mode}")
            return documents

        return await model.rerank(query, documents, top_k, **kwargs)


async def generic_rerank_api(
    query: str,
    documents: List[Dict[str, Any]],
    model: str,
    base_url: str,
    api_key: str,
    top_k: Optional[int] = None,
    **kwargs,
) -> List[Dict[str, Any]]:
    """
    Generic rerank function that works with Jina/Cohere compatible APIs.

    Args:
        query: The search query
        documents: List of documents to rerank
        model: Model identifier
        base_url: API endpoint URL
        api_key: API authentication key
        top_k: Number of top results to return
        **kwargs: Additional API-specific parameters

    Returns:
        List of reranked documents with relevance scores
    """
    if not api_key:
        logger.warning("No API key provided for rerank service")
        return documents

    if not documents:
        return documents

    # Prepare documents for reranking - handle both text and dict formats
    prepared_docs = []
    for doc in documents:
        if isinstance(doc, dict):
            # Use 'content' field if available, otherwise use 'text' or convert to string
            text = doc.get("content") or doc.get("text") or str(doc)
        else:
            text = str(doc)
        prepared_docs.append(text)

    # Prepare request
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    data = {
        "model": model,
        "query": query,
        "documents": prepared_docs,
        **kwargs,
    }

    if top_k is not None:
        data["top_k"] = min(top_k, len(prepared_docs))

    try:
        async with aiohttp.ClientSession() as session:
            async with session.post(base_url, headers=headers, json=data) as response:
                if response.status != 200:
                    error_text = await response.text()
                    logger.error(f"Rerank API error {response.status}: {error_text}")
                    return documents

                result = await response.json()

                # Extract reranked results
                if "results" in result:
                    # Standard format: results contain index and relevance_score
                    reranked_docs = []
                    for item in result["results"]:
                        if "index" in item:
                            doc_idx = item["index"]
                            if 0 <= doc_idx < len(documents):
                                reranked_doc = documents[doc_idx].copy()
                                if "relevance_score" in item:
                                    reranked_doc["rerank_score"] = item["relevance_score"]
                                reranked_docs.append(reranked_doc)
                    return reranked_docs
                else:
                    logger.warning("Unexpected rerank API response format")
                    return documents

    except Exception as e:
        logger.error(f"Error during reranking: {e}")
        return documents


async def jina_rerank(
    query: str,
    documents: List[Dict[str, Any]],
    model: str = "BAAI/bge-reranker-v2-m3",
    top_k: Optional[int] = None,
    base_url: str = "https://api.jina.ai/v1/rerank",
    api_key: Optional[str] = None,
    **kwargs,
) -> List[Dict[str, Any]]:
    """
    Rerank documents using Jina AI API.

    Args:
        query: The search query
        documents: List of documents to rerank
        model: Jina rerank model name
        top_k: Number of top results to return
        base_url: Jina API endpoint
        api_key: Jina API key
        **kwargs: Additional parameters

    Returns:
        List of reranked documents with relevance scores
    """
    if api_key is None:
        api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_API_KEY")

    return await generic_rerank_api(
        query=query,
        documents=documents,
        model=model,
        base_url=base_url,
        api_key=api_key,
        top_k=top_k,
        **kwargs,
    )


async def cohere_rerank(
    query: str,
    documents: List[Dict[str, Any]],
    model: str = "rerank-english-v2.0",
    top_k: Optional[int] = None,
    base_url: str = "https://api.cohere.ai/v1/rerank",
    api_key: Optional[str] = None,
    **kwargs,
) -> List[Dict[str, Any]]:
    """
    Rerank documents using Cohere API.

    Args:
        query: The search query
        documents: List of documents to rerank
        model: Cohere rerank model name
        top_k: Number of top results to return
        base_url: Cohere API endpoint
        api_key: Cohere API key
        **kwargs: Additional parameters

    Returns:
        List of reranked documents with relevance scores
    """
    if api_key is None:
        api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_API_KEY")

    return await generic_rerank_api(
        query=query,
        documents=documents,
        model=model,
        base_url=base_url,
        api_key=api_key,
        top_k=top_k,
        **kwargs,
    )


# Convenience function for custom API endpoints
async def custom_rerank(
    query: str,
    documents: List[Dict[str, Any]],
    model: str,
    base_url: str,
    api_key: str,
    top_k: Optional[int] = None,
    **kwargs,
) -> List[Dict[str, Any]]:
    """
    Rerank documents using a custom API endpoint.
    This is useful for self-hosted or custom rerank services.
    """
    return await generic_rerank_api(
        query=query,
        documents=documents,
        model=model,
        base_url=base_url,
        api_key=api_key,
        top_k=top_k,
        **kwargs,
    )


if __name__ == "__main__":
    import asyncio

    async def main():
        # Example usage
        docs = [
            {"content": "The capital of France is Paris."},
            {"content": "Tokyo is the capital of Japan."},
            {"content": "London is the capital of England."},
        ]

        query = "What is the capital of France?"

        result = await jina_rerank(
            query=query,
            documents=docs,
            top_k=2,
            api_key="your-api-key-here",
        )
        print(result)

    asyncio.run(main())