zrguo
committed on
Commit
·
24a98c3
1
Parent(s):
fc091ae
Simplify Configuration
Browse files
- docs/rerank_integration.md +48 -48
- env.example +0 -8
- examples/rerank_example.py +76 -43
- lightrag/lightrag.py +1 -27
- lightrag/operate.py +4 -8
- lightrag/rerank.py +72 -58
docs/rerank_integration.md
CHANGED
@@ -2,24 +2,15 @@
|
|
2 |
|
3 |
This document explains how to configure and use the rerank functionality in LightRAG to improve retrieval quality.
|
4 |
|
5 |
-
## ⚠️ Important: Parameter Priority
|
6 |
-
|
7 |
-
**QueryParam.top_k has higher priority than rerank_top_k configuration:**
|
8 |
-
|
9 |
-
- When you set `QueryParam(top_k=5)`, it will override the `rerank_top_k=10` setting in LightRAG configuration
|
10 |
-
- This means the actual number of documents sent to rerank will be determined by QueryParam.top_k
|
11 |
-
- For optimal rerank performance, always consider the top_k value in your QueryParam calls
|
12 |
-
- Example: `rag.aquery(query, param=QueryParam(mode="naive", top_k=20))` will use 20, not rerank_top_k
|
13 |
-
|
14 |
## Overview
|
15 |
|
16 |
Reranking is an optional feature that improves the quality of retrieved documents by re-ordering them based on their relevance to the query. This is particularly useful when you want higher precision in document retrieval across all query modes (naive, local, global, hybrid, mix).
|
17 |
|
18 |
## Architecture
|
19 |
|
20 |
-
The rerank integration follows the same design pattern as the LLM integration:
|
21 |
|
22 |
-
- **
|
23 |
- **Async Processing**: Non-blocking rerank operations
|
24 |
- **Error Handling**: Graceful fallback to original results
|
25 |
- **Optional Feature**: Can be enabled/disabled via configuration
|
@@ -29,24 +20,11 @@ The rerank integration follows the same design pattern as the LLM integration:
|
|
29 |
|
30 |
### Environment Variables
|
31 |
|
32 |
-
Set
|
33 |
|
34 |
```bash
|
35 |
# Enable/disable reranking
|
36 |
ENABLE_RERANK=True
|
37 |
-
|
38 |
-
# Rerank model configuration
|
39 |
-
RERANK_MODEL=BAAI/bge-reranker-v2-m3
|
40 |
-
RERANK_MAX_ASYNC=4
|
41 |
-
RERANK_TOP_K=10
|
42 |
-
|
43 |
-
# API configuration
|
44 |
-
RERANK_API_KEY=your_rerank_api_key_here
|
45 |
-
RERANK_BASE_URL=https://api.your-provider.com/v1/rerank
|
46 |
-
|
47 |
-
# Provider-specific keys (optional alternatives)
|
48 |
-
JINA_API_KEY=your_jina_api_key_here
|
49 |
-
COHERE_API_KEY=your_cohere_api_key_here
|
50 |
```
|
51 |
|
52 |
### Programmatic Configuration
|
@@ -55,15 +33,27 @@ COHERE_API_KEY=your_cohere_api_key_here
|
|
55 |
from lightrag import LightRAG
|
56 |
from lightrag.rerank import custom_rerank, RerankModel
|
57 |
|
58 |
-
# Method 1: Using
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
rag = LightRAG(
|
60 |
working_dir="./rag_storage",
|
61 |
llm_model_func=your_llm_func,
|
62 |
embedding_func=your_embedding_func,
|
63 |
-
|
|
|
64 |
)
|
65 |
|
66 |
-
# Method 2:
|
67 |
rerank_model = RerankModel(
|
68 |
rerank_func=custom_rerank,
|
69 |
kwargs={
|
@@ -79,7 +69,6 @@ rag = LightRAG(
|
|
79 |
embedding_func=your_embedding_func,
|
80 |
enable_rerank=True,
|
81 |
rerank_model_func=rerank_model.rerank,
|
82 |
-
rerank_top_k=10,
|
83 |
)
|
84 |
```
|
85 |
|
@@ -112,7 +101,8 @@ result = await jina_rerank(
|
|
112 |
query="your query",
|
113 |
documents=documents,
|
114 |
model="BAAI/bge-reranker-v2-m3",
|
115 |
-
api_key="your_jina_api_key"
|
|
|
116 |
)
|
117 |
```
|
118 |
|
@@ -125,7 +115,8 @@ result = await cohere_rerank(
|
|
125 |
query="your query",
|
126 |
documents=documents,
|
127 |
model="rerank-english-v2.0",
|
128 |
-
api_key="your_cohere_api_key"
|
|
|
129 |
)
|
130 |
```
|
131 |
|
@@ -143,11 +134,7 @@ Reranking is automatically applied at these key retrieval stages:
|
|
143 |
| Parameter | Type | Default | Description |
|
144 |
|-----------|------|---------|-------------|
|
145 |
| `enable_rerank` | bool | False | Enable/disable reranking |
|
146 |
-
| `
|
147 |
-
| `rerank_model_max_async` | int | 4 | Max concurrent rerank calls |
|
148 |
-
| `rerank_top_k` | int | 10 | Number of top results to return ⚠️ **Overridden by QueryParam.top_k** |
|
149 |
-
| `rerank_model_func` | callable | None | Custom rerank function |
|
150 |
-
| `rerank_model_kwargs` | dict | {} | Additional rerank parameters |
|
151 |
|
152 |
## Example Usage
|
153 |
|
@@ -157,6 +144,18 @@ Reranking is automatically applied at these key retrieval stages:
|
|
157 |
import asyncio
|
158 |
from lightrag import LightRAG, QueryParam
|
159 |
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embedding
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
|
161 |
async def main():
|
162 |
# Initialize with rerank enabled
|
@@ -165,20 +164,21 @@ async def main():
|
|
165 |
llm_model_func=gpt_4o_mini_complete,
|
166 |
embedding_func=openai_embedding,
|
167 |
enable_rerank=True,
|
|
|
168 |
)
|
169 |
-
|
170 |
# Insert documents
|
171 |
await rag.ainsert([
|
172 |
"Document 1 content...",
|
173 |
"Document 2 content...",
|
174 |
])
|
175 |
-
|
176 |
# Query with rerank (automatically applied)
|
177 |
result = await rag.aquery(
|
178 |
"Your question here",
|
179 |
-
param=QueryParam(mode="hybrid", top_k=5) #
|
180 |
)
|
181 |
-
|
182 |
print(result)
|
183 |
|
184 |
asyncio.run(main())
|
@@ -195,7 +195,7 @@ async def test_rerank():
|
|
195 |
{"content": "Text about topic B"},
|
196 |
{"content": "Text about topic C"},
|
197 |
]
|
198 |
-
|
199 |
reranked = await custom_rerank(
|
200 |
query="Tell me about topic A",
|
201 |
documents=documents,
|
@@ -204,26 +204,26 @@ async def test_rerank():
|
|
204 |
api_key="your_api_key_here",
|
205 |
top_k=2
|
206 |
)
|
207 |
-
|
208 |
for doc in reranked:
|
209 |
print(f"Score: {doc.get('rerank_score')}, Content: {doc.get('content')}")
|
210 |
```
|
211 |
|
212 |
## Best Practices
|
213 |
|
214 |
-
1. **
|
215 |
2. **Performance**: Use reranking selectively for better performance vs. quality tradeoff
|
216 |
-
3. **API Limits**: Monitor API usage and implement rate limiting
|
217 |
4. **Fallback**: Always handle rerank failures gracefully (returns original results)
|
218 |
-
5. **Top-k
|
219 |
6. **Cost Management**: Consider rerank API costs in your budget planning
|
220 |
|
221 |
## Troubleshooting
|
222 |
|
223 |
### Common Issues
|
224 |
|
225 |
-
1. **API Key Missing**: Ensure
|
226 |
-
2. **Network Issues**: Check
|
227 |
3. **Model Errors**: Verify the rerank model name is supported by your API
|
228 |
4. **Document Format**: Ensure documents have `content` or `text` fields
|
229 |
|
@@ -268,4 +268,4 @@ The generic rerank API expects this response format:
|
|
268 |
This is compatible with:
|
269 |
- Jina AI Rerank API
|
270 |
- Cohere Rerank API
|
271 |
-
- Custom APIs following the same format
|
|
|
2 |
|
3 |
This document explains how to configure and use the rerank functionality in LightRAG to improve retrieval quality.
|
4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
## Overview
|
6 |
|
7 |
Reranking is an optional feature that improves the quality of retrieved documents by re-ordering them based on their relevance to the query. This is particularly useful when you want higher precision in document retrieval across all query modes (naive, local, global, hybrid, mix).
|
8 |
|
9 |
## Architecture
|
10 |
|
11 |
+
The rerank integration follows a simplified design pattern:
|
12 |
|
13 |
+
- **Single Function Configuration**: All rerank settings (model, API keys, top_k, etc.) are contained within the rerank function
|
14 |
- **Async Processing**: Non-blocking rerank operations
|
15 |
- **Error Handling**: Graceful fallback to original results
|
16 |
- **Optional Feature**: Can be enabled/disabled via configuration
|
|
|
20 |
|
21 |
### Environment Variables
|
22 |
|
23 |
+
Set this variable in your `.env` file or environment:
|
24 |
|
25 |
```bash
|
26 |
# Enable/disable reranking
|
27 |
ENABLE_RERANK=True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
```
|
29 |
|
30 |
### Programmatic Configuration
|
|
|
33 |
from lightrag import LightRAG
|
34 |
from lightrag.rerank import custom_rerank, RerankModel
|
35 |
|
36 |
+
# Method 1: Using a custom rerank function with all settings included
|
37 |
+
async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
|
38 |
+
return await custom_rerank(
|
39 |
+
query=query,
|
40 |
+
documents=documents,
|
41 |
+
model="BAAI/bge-reranker-v2-m3",
|
42 |
+
base_url="https://api.your-provider.com/v1/rerank",
|
43 |
+
api_key="your_api_key_here",
|
44 |
+
top_k=top_k or 10, # Handle top_k within the function
|
45 |
+
**kwargs
|
46 |
+
)
|
47 |
+
|
48 |
rag = LightRAG(
|
49 |
working_dir="./rag_storage",
|
50 |
llm_model_func=your_llm_func,
|
51 |
embedding_func=your_embedding_func,
|
52 |
+
enable_rerank=True,
|
53 |
+
rerank_model_func=my_rerank_func,
|
54 |
)
|
55 |
|
56 |
+
# Method 2: Using RerankModel wrapper
|
57 |
rerank_model = RerankModel(
|
58 |
rerank_func=custom_rerank,
|
59 |
kwargs={
|
|
|
69 |
embedding_func=your_embedding_func,
|
70 |
enable_rerank=True,
|
71 |
rerank_model_func=rerank_model.rerank,
|
|
|
72 |
)
|
73 |
```
|
74 |
|
|
|
101 |
query="your query",
|
102 |
documents=documents,
|
103 |
model="BAAI/bge-reranker-v2-m3",
|
104 |
+
api_key="your_jina_api_key",
|
105 |
+
top_k=10
|
106 |
)
|
107 |
```
|
108 |
|
|
|
115 |
query="your query",
|
116 |
documents=documents,
|
117 |
model="rerank-english-v2.0",
|
118 |
+
api_key="your_cohere_api_key",
|
119 |
+
top_k=10
|
120 |
)
|
121 |
```
|
122 |
|
|
|
134 |
| Parameter | Type | Default | Description |
|
135 |
|-----------|------|---------|-------------|
|
136 |
| `enable_rerank` | bool | False | Enable/disable reranking |
|
137 |
+
| `rerank_model_func` | callable | None | Custom rerank function containing all configurations (model, API keys, top_k, etc.) |
|
|
|
|
|
|
|
|
|
138 |
|
139 |
## Example Usage
|
140 |
|
|
|
144 |
import asyncio
|
145 |
from lightrag import LightRAG, QueryParam
|
146 |
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embedding
|
147 |
+
from lightrag.rerank import jina_rerank
|
148 |
+
|
149 |
+
async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
|
150 |
+
"""Custom rerank function with all settings included"""
|
151 |
+
return await jina_rerank(
|
152 |
+
query=query,
|
153 |
+
documents=documents,
|
154 |
+
model="BAAI/bge-reranker-v2-m3",
|
155 |
+
api_key="your_jina_api_key_here",
|
156 |
+
top_k=top_k or 10, # Default top_k if not provided
|
157 |
+
**kwargs
|
158 |
+
)
|
159 |
|
160 |
async def main():
|
161 |
# Initialize with rerank enabled
|
|
|
164 |
llm_model_func=gpt_4o_mini_complete,
|
165 |
embedding_func=openai_embedding,
|
166 |
enable_rerank=True,
|
167 |
+
rerank_model_func=my_rerank_func,
|
168 |
)
|
169 |
+
|
170 |
# Insert documents
|
171 |
await rag.ainsert([
|
172 |
"Document 1 content...",
|
173 |
"Document 2 content...",
|
174 |
])
|
175 |
+
|
176 |
# Query with rerank (automatically applied)
|
177 |
result = await rag.aquery(
|
178 |
"Your question here",
|
179 |
+
param=QueryParam(mode="hybrid", top_k=5) # This top_k is passed to rerank function
|
180 |
)
|
181 |
+
|
182 |
print(result)
|
183 |
|
184 |
asyncio.run(main())
|
|
|
195 |
{"content": "Text about topic B"},
|
196 |
{"content": "Text about topic C"},
|
197 |
]
|
198 |
+
|
199 |
reranked = await custom_rerank(
|
200 |
query="Tell me about topic A",
|
201 |
documents=documents,
|
|
|
204 |
api_key="your_api_key_here",
|
205 |
top_k=2
|
206 |
)
|
207 |
+
|
208 |
for doc in reranked:
|
209 |
print(f"Score: {doc.get('rerank_score')}, Content: {doc.get('content')}")
|
210 |
```
|
211 |
|
212 |
## Best Practices
|
213 |
|
214 |
+
1. **Self-Contained Functions**: Include all necessary configurations (API keys, models, top_k handling) within your rerank function
|
215 |
2. **Performance**: Use reranking selectively for better performance vs. quality tradeoff
|
216 |
+
3. **API Limits**: Monitor API usage and implement rate limiting within your rerank function
|
217 |
4. **Fallback**: Always handle rerank failures gracefully (returns original results)
|
218 |
+
5. **Top-k Handling**: Handle top_k parameter appropriately within your rerank function
|
219 |
6. **Cost Management**: Consider rerank API costs in your budget planning
|
220 |
|
221 |
## Troubleshooting
|
222 |
|
223 |
### Common Issues
|
224 |
|
225 |
+
1. **API Key Missing**: Ensure API keys are properly configured within your rerank function
|
226 |
+
2. **Network Issues**: Check API endpoints and network connectivity
|
227 |
3. **Model Errors**: Verify the rerank model name is supported by your API
|
228 |
4. **Document Format**: Ensure documents have `content` or `text` fields
|
229 |
|
|
|
268 |
This is compatible with:
|
269 |
- Jina AI Rerank API
|
270 |
- Cohere Rerank API
|
271 |
+
- Custom APIs following the same format
|
env.example
CHANGED
@@ -182,11 +182,3 @@ REDIS_URI=redis://localhost:6379
|
|
182 |
|
183 |
# Rerank Configuration
|
184 |
ENABLE_RERANK=False
|
185 |
-
RERANK_MODEL=BAAI/bge-reranker-v2-m3
|
186 |
-
RERANK_MAX_ASYNC=4
|
187 |
-
RERANK_TOP_K=10
|
188 |
-
# Note: QueryParam.top_k in your code will override RERANK_TOP_K setting
|
189 |
-
|
190 |
-
# Rerank API Configuration
|
191 |
-
RERANK_API_KEY=your_rerank_api_key_here
|
192 |
-
RERANK_BASE_URL=https://api.your-provider.com/v1/rerank
|
|
|
182 |
|
183 |
# Rerank Configuration
|
184 |
ENABLE_RERANK=False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
examples/rerank_example.py
CHANGED
@@ -4,19 +4,12 @@ LightRAG Rerank Integration Example
|
|
4 |
This example demonstrates how to use rerank functionality with LightRAG
|
5 |
to improve retrieval quality across different query modes.
|
6 |
|
7 |
-
IMPORTANT: Parameter Priority
|
8 |
-
- QueryParam(top_k=N) has higher priority than rerank_top_k in LightRAG configuration
|
9 |
-
- If you set QueryParam(top_k=5), it will override rerank_top_k setting
|
10 |
-
- For optimal rerank performance, use appropriate top_k values in QueryParam
|
11 |
-
|
12 |
Configuration Required:
|
13 |
1. Set your LLM API key and base URL in llm_model_func()
|
14 |
-
2. Set your embedding API key and base URL in embedding_func()
|
15 |
3. Set your rerank API key and base URL in the rerank configuration
|
16 |
4. Or use environment variables (.env file):
|
17 |
-
-
|
18 |
-
- RERANK_BASE_URL=https://your-actual-rerank-endpoint/v1/rerank
|
19 |
-
- RERANK_MODEL=your_rerank_model_name
|
20 |
"""
|
21 |
|
22 |
import asyncio
|
@@ -35,6 +28,7 @@ setup_logger("test_rerank")
|
|
35 |
if not os.path.exists(WORKING_DIR):
|
36 |
os.mkdir(WORKING_DIR)
|
37 |
|
|
|
38 |
async def llm_model_func(
|
39 |
prompt, system_prompt=None, history_messages=[], **kwargs
|
40 |
) -> str:
|
@@ -48,6 +42,7 @@ async def llm_model_func(
|
|
48 |
**kwargs,
|
49 |
)
|
50 |
|
|
|
51 |
async def embedding_func(texts: list[str]) -> np.ndarray:
|
52 |
return await openai_embed(
|
53 |
texts,
|
@@ -56,25 +51,63 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
|
|
56 |
base_url="https://api.your-embedding-provider.com/v1",
|
57 |
)
|
58 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
async def create_rag_with_rerank():
|
60 |
"""Create LightRAG instance with rerank configuration"""
|
61 |
-
|
62 |
# Get embedding dimension
|
63 |
test_embedding = await embedding_func(["test"])
|
64 |
embedding_dim = test_embedding.shape[1]
|
65 |
print(f"Detected embedding dimension: {embedding_dim}")
|
66 |
|
67 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
rerank_model = RerankModel(
|
69 |
rerank_func=custom_rerank,
|
70 |
kwargs={
|
71 |
"model": "BAAI/bge-reranker-v2-m3",
|
72 |
"base_url": "https://api.your-rerank-provider.com/v1/rerank",
|
73 |
"api_key": "your_rerank_api_key_here",
|
74 |
-
}
|
75 |
)
|
76 |
|
77 |
-
# Initialize LightRAG with rerank
|
78 |
rag = LightRAG(
|
79 |
working_dir=WORKING_DIR,
|
80 |
llm_model_func=llm_model_func,
|
@@ -83,69 +116,66 @@ async def create_rag_with_rerank():
|
|
83 |
max_token_size=8192,
|
84 |
func=embedding_func,
|
85 |
),
|
86 |
-
# Rerank Configuration
|
87 |
enable_rerank=True,
|
88 |
rerank_model_func=rerank_model.rerank,
|
89 |
-
rerank_top_k=10, # Note: QueryParam.top_k will override this
|
90 |
)
|
91 |
|
92 |
return rag
|
93 |
|
|
|
94 |
async def test_rerank_with_different_topk():
|
95 |
"""
|
96 |
-
Test rerank functionality with different top_k settings
|
97 |
"""
|
98 |
print("🚀 Setting up LightRAG with Rerank functionality...")
|
99 |
-
|
100 |
rag = await create_rag_with_rerank()
|
101 |
-
|
102 |
# Insert sample documents
|
103 |
sample_docs = [
|
104 |
"Reranking improves retrieval quality by re-ordering documents based on relevance.",
|
105 |
"LightRAG is a powerful retrieval-augmented generation system with multiple query modes.",
|
106 |
"Vector databases enable efficient similarity search in high-dimensional embedding spaces.",
|
107 |
"Natural language processing has evolved with large language models and transformers.",
|
108 |
-
"Machine learning algorithms can learn patterns from data without explicit programming."
|
109 |
]
|
110 |
-
|
111 |
print("📄 Inserting sample documents...")
|
112 |
await rag.ainsert(sample_docs)
|
113 |
-
|
114 |
query = "How does reranking improve retrieval quality?"
|
115 |
print(f"\n🔍 Testing query: '{query}'")
|
116 |
print("=" * 80)
|
117 |
-
|
118 |
# Test different top_k values to show parameter priority
|
119 |
top_k_values = [2, 5, 10]
|
120 |
-
|
121 |
for top_k in top_k_values:
|
122 |
-
print(f"\n📊 Testing with QueryParam(top_k={top_k})
|
123 |
-
|
124 |
# Test naive mode with specific top_k
|
125 |
-
result = await rag.aquery(
|
126 |
-
query,
|
127 |
-
param=QueryParam(mode="naive", top_k=top_k)
|
128 |
-
)
|
129 |
print(f" Result length: {len(result)} characters")
|
130 |
print(f" Preview: {result[:100]}...")
|
131 |
|
|
|
132 |
async def test_direct_rerank():
|
133 |
"""Test rerank function directly"""
|
134 |
print("\n🔧 Direct Rerank API Test")
|
135 |
print("=" * 40)
|
136 |
-
|
137 |
documents = [
|
138 |
{"content": "Reranking significantly improves retrieval quality"},
|
139 |
{"content": "LightRAG supports advanced reranking capabilities"},
|
140 |
{"content": "Vector search finds semantically similar documents"},
|
141 |
{"content": "Natural language processing with modern transformers"},
|
142 |
-
{"content": "The quick brown fox jumps over the lazy dog"}
|
143 |
]
|
144 |
-
|
145 |
query = "rerank improve quality"
|
146 |
print(f"Query: '{query}'")
|
147 |
print(f"Documents: {len(documents)}")
|
148 |
-
|
149 |
try:
|
150 |
reranked_docs = await custom_rerank(
|
151 |
query=query,
|
@@ -153,41 +183,44 @@ async def test_direct_rerank():
|
|
153 |
model="BAAI/bge-reranker-v2-m3",
|
154 |
base_url="https://api.your-rerank-provider.com/v1/rerank",
|
155 |
api_key="your_rerank_api_key_here",
|
156 |
-
top_k=3
|
157 |
)
|
158 |
-
|
159 |
print("\n✅ Rerank Results:")
|
160 |
for i, doc in enumerate(reranked_docs):
|
161 |
score = doc.get("rerank_score", "N/A")
|
162 |
content = doc.get("content", "")[:60]
|
163 |
print(f" {i+1}. Score: {score:.4f} | {content}...")
|
164 |
-
|
165 |
except Exception as e:
|
166 |
print(f"❌ Rerank failed: {e}")
|
167 |
|
|
|
168 |
async def main():
|
169 |
"""Main example function"""
|
170 |
print("🎯 LightRAG Rerank Integration Example")
|
171 |
print("=" * 60)
|
172 |
-
|
173 |
try:
|
174 |
# Test rerank with different top_k values
|
175 |
await test_rerank_with_different_topk()
|
176 |
-
|
177 |
# Test direct rerank
|
178 |
await test_direct_rerank()
|
179 |
-
|
180 |
print("\n✅ Example completed successfully!")
|
181 |
print("\n💡 Key Points:")
|
182 |
-
print(" ✓
|
183 |
print(" ✓ Rerank improves document relevance ordering")
|
184 |
-
print(" ✓ Configure API keys
|
185 |
print(" ✓ Monitor API usage and costs when using rerank services")
|
186 |
-
|
187 |
except Exception as e:
|
188 |
print(f"\n❌ Example failed: {e}")
|
189 |
import traceback
|
|
|
190 |
traceback.print_exc()
|
191 |
|
|
|
192 |
if __name__ == "__main__":
|
193 |
-
asyncio.run(main())
|
|
|
4 |
This example demonstrates how to use rerank functionality with LightRAG
|
5 |
to improve retrieval quality across different query modes.
|
6 |
|
|
|
|
|
|
|
|
|
|
|
7 |
Configuration Required:
|
8 |
1. Set your LLM API key and base URL in llm_model_func()
|
9 |
+
2. Set your embedding API key and base URL in embedding_func()
|
10 |
3. Set your rerank API key and base URL in the rerank configuration
|
11 |
4. Or use environment variables (.env file):
|
12 |
+
- ENABLE_RERANK=True
|
|
|
|
|
13 |
"""
|
14 |
|
15 |
import asyncio
|
|
|
28 |
if not os.path.exists(WORKING_DIR):
|
29 |
os.mkdir(WORKING_DIR)
|
30 |
|
31 |
+
|
32 |
async def llm_model_func(
|
33 |
prompt, system_prompt=None, history_messages=[], **kwargs
|
34 |
) -> str:
|
|
|
42 |
**kwargs,
|
43 |
)
|
44 |
|
45 |
+
|
46 |
async def embedding_func(texts: list[str]) -> np.ndarray:
|
47 |
return await openai_embed(
|
48 |
texts,
|
|
|
51 |
base_url="https://api.your-embedding-provider.com/v1",
|
52 |
)
|
53 |
|
54 |
+
|
55 |
+
async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
|
56 |
+
"""Custom rerank function with all settings included"""
|
57 |
+
return await custom_rerank(
|
58 |
+
query=query,
|
59 |
+
documents=documents,
|
60 |
+
model="BAAI/bge-reranker-v2-m3",
|
61 |
+
base_url="https://api.your-rerank-provider.com/v1/rerank",
|
62 |
+
api_key="your_rerank_api_key_here",
|
63 |
+
top_k=top_k or 10, # Default top_k if not provided
|
64 |
+
**kwargs,
|
65 |
+
)
|
66 |
+
|
67 |
+
|
68 |
async def create_rag_with_rerank():
|
69 |
"""Create LightRAG instance with rerank configuration"""
|
70 |
+
|
71 |
# Get embedding dimension
|
72 |
test_embedding = await embedding_func(["test"])
|
73 |
embedding_dim = test_embedding.shape[1]
|
74 |
print(f"Detected embedding dimension: {embedding_dim}")
|
75 |
|
76 |
+
# Method 1: Using custom rerank function
|
77 |
+
rag = LightRAG(
|
78 |
+
working_dir=WORKING_DIR,
|
79 |
+
llm_model_func=llm_model_func,
|
80 |
+
embedding_func=EmbeddingFunc(
|
81 |
+
embedding_dim=embedding_dim,
|
82 |
+
max_token_size=8192,
|
83 |
+
func=embedding_func,
|
84 |
+
),
|
85 |
+
# Simplified Rerank Configuration
|
86 |
+
enable_rerank=True,
|
87 |
+
rerank_model_func=my_rerank_func,
|
88 |
+
)
|
89 |
+
|
90 |
+
return rag
|
91 |
+
|
92 |
+
|
93 |
+
async def create_rag_with_rerank_model():
|
94 |
+
"""Alternative: Create LightRAG instance using RerankModel wrapper"""
|
95 |
+
|
96 |
+
# Get embedding dimension
|
97 |
+
test_embedding = await embedding_func(["test"])
|
98 |
+
embedding_dim = test_embedding.shape[1]
|
99 |
+
print(f"Detected embedding dimension: {embedding_dim}")
|
100 |
+
|
101 |
+
# Method 2: Using RerankModel wrapper
|
102 |
rerank_model = RerankModel(
|
103 |
rerank_func=custom_rerank,
|
104 |
kwargs={
|
105 |
"model": "BAAI/bge-reranker-v2-m3",
|
106 |
"base_url": "https://api.your-rerank-provider.com/v1/rerank",
|
107 |
"api_key": "your_rerank_api_key_here",
|
108 |
+
},
|
109 |
)
|
110 |
|
|
|
111 |
rag = LightRAG(
|
112 |
working_dir=WORKING_DIR,
|
113 |
llm_model_func=llm_model_func,
|
|
|
116 |
max_token_size=8192,
|
117 |
func=embedding_func,
|
118 |
),
|
|
|
119 |
enable_rerank=True,
|
120 |
rerank_model_func=rerank_model.rerank,
|
|
|
121 |
)
|
122 |
|
123 |
return rag
|
124 |
|
125 |
+
|
126 |
async def test_rerank_with_different_topk():
|
127 |
"""
|
128 |
+
Test rerank functionality with different top_k settings
|
129 |
"""
|
130 |
print("🚀 Setting up LightRAG with Rerank functionality...")
|
131 |
+
|
132 |
rag = await create_rag_with_rerank()
|
133 |
+
|
134 |
# Insert sample documents
|
135 |
sample_docs = [
|
136 |
"Reranking improves retrieval quality by re-ordering documents based on relevance.",
|
137 |
"LightRAG is a powerful retrieval-augmented generation system with multiple query modes.",
|
138 |
"Vector databases enable efficient similarity search in high-dimensional embedding spaces.",
|
139 |
"Natural language processing has evolved with large language models and transformers.",
|
140 |
+
"Machine learning algorithms can learn patterns from data without explicit programming.",
|
141 |
]
|
142 |
+
|
143 |
print("📄 Inserting sample documents...")
|
144 |
await rag.ainsert(sample_docs)
|
145 |
+
|
146 |
query = "How does reranking improve retrieval quality?"
|
147 |
print(f"\n🔍 Testing query: '{query}'")
|
148 |
print("=" * 80)
|
149 |
+
|
150 |
# Test different top_k values to show parameter priority
|
151 |
top_k_values = [2, 5, 10]
|
152 |
+
|
153 |
for top_k in top_k_values:
|
154 |
+
print(f"\n📊 Testing with QueryParam(top_k={top_k}):")
|
155 |
+
|
156 |
# Test naive mode with specific top_k
|
157 |
+
result = await rag.aquery(query, param=QueryParam(mode="naive", top_k=top_k))
|
|
|
|
|
|
|
158 |
print(f" Result length: {len(result)} characters")
|
159 |
print(f" Preview: {result[:100]}...")
|
160 |
|
161 |
+
|
162 |
async def test_direct_rerank():
|
163 |
"""Test rerank function directly"""
|
164 |
print("\n🔧 Direct Rerank API Test")
|
165 |
print("=" * 40)
|
166 |
+
|
167 |
documents = [
|
168 |
{"content": "Reranking significantly improves retrieval quality"},
|
169 |
{"content": "LightRAG supports advanced reranking capabilities"},
|
170 |
{"content": "Vector search finds semantically similar documents"},
|
171 |
{"content": "Natural language processing with modern transformers"},
|
172 |
+
{"content": "The quick brown fox jumps over the lazy dog"},
|
173 |
]
|
174 |
+
|
175 |
query = "rerank improve quality"
|
176 |
print(f"Query: '{query}'")
|
177 |
print(f"Documents: {len(documents)}")
|
178 |
+
|
179 |
try:
|
180 |
reranked_docs = await custom_rerank(
|
181 |
query=query,
|
|
|
183 |
model="BAAI/bge-reranker-v2-m3",
|
184 |
base_url="https://api.your-rerank-provider.com/v1/rerank",
|
185 |
api_key="your_rerank_api_key_here",
|
186 |
+
top_k=3,
|
187 |
)
|
188 |
+
|
189 |
print("\n✅ Rerank Results:")
|
190 |
for i, doc in enumerate(reranked_docs):
|
191 |
score = doc.get("rerank_score", "N/A")
|
192 |
content = doc.get("content", "")[:60]
|
193 |
print(f" {i+1}. Score: {score:.4f} | {content}...")
|
194 |
+
|
195 |
except Exception as e:
|
196 |
print(f"❌ Rerank failed: {e}")
|
197 |
|
198 |
+
|
199 |
async def main():
|
200 |
"""Main example function"""
|
201 |
print("🎯 LightRAG Rerank Integration Example")
|
202 |
print("=" * 60)
|
203 |
+
|
204 |
try:
|
205 |
# Test rerank with different top_k values
|
206 |
await test_rerank_with_different_topk()
|
207 |
+
|
208 |
# Test direct rerank
|
209 |
await test_direct_rerank()
|
210 |
+
|
211 |
print("\n✅ Example completed successfully!")
|
212 |
print("\n💡 Key Points:")
|
213 |
+
print(" ✓ All rerank configurations are contained within rerank_model_func")
|
214 |
print(" ✓ Rerank improves document relevance ordering")
|
215 |
+
print(" ✓ Configure API keys within your rerank function")
|
216 |
print(" ✓ Monitor API usage and costs when using rerank services")
|
217 |
+
|
218 |
except Exception as e:
|
219 |
print(f"\n❌ Example failed: {e}")
|
220 |
import traceback
|
221 |
+
|
222 |
traceback.print_exc()
|
223 |
|
224 |
+
|
225 |
if __name__ == "__main__":
|
226 |
+
asyncio.run(main())
|
lightrag/lightrag.py
CHANGED
@@ -249,25 +249,7 @@ class LightRAG:
|
|
249 |
"""Enable reranking for improved retrieval quality. Defaults to False."""
|
250 |
|
251 |
rerank_model_func: Callable[..., object] | None = field(default=None)
|
252 |
-
"""Function for reranking retrieved documents. Optional."""
|
253 |
-
|
254 |
-
rerank_model_name: str = field(
|
255 |
-
default=os.getenv("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
|
256 |
-
)
|
257 |
-
"""Name of the rerank model used for reranking documents."""
|
258 |
-
|
259 |
-
rerank_model_max_async: int = field(default=int(os.getenv("RERANK_MAX_ASYNC", 4)))
|
260 |
-
"""Maximum number of concurrent rerank calls."""
|
261 |
-
|
262 |
-
rerank_model_kwargs: dict[str, Any] = field(default_factory=dict)
|
263 |
-
"""Additional keyword arguments passed to the rerank model function."""
|
264 |
-
|
265 |
-
rerank_top_k: int = field(default=int(os.getenv("RERANK_TOP_K", 10)))
|
266 |
-
"""Number of top documents to return after reranking.
|
267 |
-
|
268 |
-
Note: This value will be overridden by QueryParam.top_k in query calls.
|
269 |
-
Example: QueryParam(top_k=5) will override rerank_top_k=10 setting.
|
270 |
-
"""
|
271 |
|
272 |
# Storage
|
273 |
# ---
|
@@ -475,14 +457,6 @@ class LightRAG:
|
|
475 |
|
476 |
# Init Rerank
|
477 |
if self.enable_rerank and self.rerank_model_func:
|
478 |
-
self.rerank_model_func = priority_limit_async_func_call(
|
479 |
-
self.rerank_model_max_async
|
480 |
-
)(
|
481 |
-
partial(
|
482 |
-
self.rerank_model_func, # type: ignore
|
483 |
-
**self.rerank_model_kwargs,
|
484 |
-
)
|
485 |
-
)
|
486 |
logger.info("Rerank model initialized for improved retrieval quality")
|
487 |
elif self.enable_rerank and not self.rerank_model_func:
|
488 |
logger.warning(
|
|
|
249 |
"""Enable reranking for improved retrieval quality. Defaults to False."""
|
250 |
|
251 |
rerank_model_func: Callable[..., object] | None = field(default=None)
|
252 |
+
"""Function for reranking retrieved documents. All rerank configurations (model name, API keys, top_k, etc.) should be included in this function. Optional."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
253 |
|
254 |
# Storage
|
255 |
# ---
|
|
|
457 |
|
458 |
# Init Rerank
|
459 |
if self.enable_rerank and self.rerank_model_func:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
460 |
logger.info("Rerank model initialized for improved retrieval quality")
|
461 |
elif self.enable_rerank and not self.rerank_model_func:
|
462 |
logger.warning(
|
lightrag/operate.py
CHANGED
@@ -2864,19 +2864,15 @@ async def apply_rerank_if_enabled(
|
|
2864 |
return retrieved_docs
|
2865 |
|
2866 |
try:
|
2867 |
-
# Determine top_k for reranking
|
2868 |
-
rerank_top_k = top_k or global_config.get("rerank_top_k", 10)
|
2869 |
-
rerank_top_k = min(rerank_top_k, len(retrieved_docs))
|
2870 |
-
|
2871 |
logger.debug(
|
2872 |
-
f"Applying rerank to {len(retrieved_docs)} documents, returning top {
|
2873 |
)
|
2874 |
|
2875 |
-
# Apply reranking
|
2876 |
reranked_docs = await rerank_func(
|
2877 |
query=query,
|
2878 |
documents=retrieved_docs,
|
2879 |
-
top_k=
|
2880 |
)
|
2881 |
|
2882 |
if reranked_docs and len(reranked_docs) > 0:
|
@@ -2886,7 +2882,7 @@ async def apply_rerank_if_enabled(
|
|
2886 |
return reranked_docs
|
2887 |
else:
|
2888 |
logger.warning("Rerank returned empty results, using original documents")
|
2889 |
-
return retrieved_docs
|
2890 |
|
2891 |
except Exception as e:
|
2892 |
logger.error(f"Error during reranking: {e}, using original documents")
|
|
|
2864 |
return retrieved_docs
|
2865 |
|
2866 |
try:
|
|
|
|
|
|
|
|
|
2867 |
logger.debug(
|
2868 |
+
f"Applying rerank to {len(retrieved_docs)} documents, returning top {top_k}"
|
2869 |
)
|
2870 |
|
2871 |
+
# Apply reranking - let rerank_model_func handle top_k internally
|
2872 |
reranked_docs = await rerank_func(
|
2873 |
query=query,
|
2874 |
documents=retrieved_docs,
|
2875 |
+
top_k=top_k,
|
2876 |
)
|
2877 |
|
2878 |
if reranked_docs and len(reranked_docs) > 0:
|
|
|
2882 |
return reranked_docs
|
2883 |
else:
|
2884 |
logger.warning("Rerank returned empty results, using original documents")
|
2885 |
+
return retrieved_docs
|
2886 |
|
2887 |
except Exception as e:
|
2888 |
logger.error(f"Error during reranking: {e}, using original documents")
|
lightrag/rerank.py
CHANGED
@@ -1,12 +1,9 @@
|
|
1 |
from __future__ import annotations
|
2 |
|
3 |
import os
|
4 |
-
import json
|
5 |
import aiohttp
|
6 |
-
import numpy as np
|
7 |
from typing import Callable, Any, List, Dict, Optional
|
8 |
from pydantic import BaseModel, Field
|
9 |
-
from dataclasses import asdict
|
10 |
|
11 |
from .utils import logger
|
12 |
|
@@ -15,14 +12,17 @@ class RerankModel(BaseModel):
|
|
15 |
"""
|
16 |
Pydantic model class for defining a custom rerank model.
|
17 |
|
|
|
|
|
|
|
18 |
Attributes:
|
19 |
rerank_func (Callable[[Any], List[Dict]]): A callable function that reranks documents.
|
20 |
The function should take query and documents as input and return reranked results.
|
21 |
kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function.
|
22 |
-
This
|
23 |
|
24 |
Example usage:
|
25 |
-
Rerank model example
|
26 |
```python
|
27 |
rerank_model = RerankModel(
|
28 |
rerank_func=jina_rerank,
|
@@ -32,6 +32,32 @@ class RerankModel(BaseModel):
|
|
32 |
"base_url": "https://api.jina.ai/v1/rerank"
|
33 |
}
|
34 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
```
|
36 |
"""
|
37 |
|
@@ -43,25 +69,22 @@ class RerankModel(BaseModel):
|
|
43 |
query: str,
|
44 |
documents: List[Dict[str, Any]],
|
45 |
top_k: Optional[int] = None,
|
46 |
-
**extra_kwargs
|
47 |
) -> List[Dict[str, Any]]:
|
48 |
"""Rerank documents using the configured model function."""
|
49 |
# Merge extra kwargs with model kwargs
|
50 |
kwargs = {**self.kwargs, **extra_kwargs}
|
51 |
return await self.rerank_func(
|
52 |
-
query=query,
|
53 |
-
documents=documents,
|
54 |
-
top_k=top_k,
|
55 |
-
**kwargs
|
56 |
)
|
57 |
|
58 |
|
59 |
class MultiRerankModel(BaseModel):
|
60 |
"""Multiple rerank models for different modes/scenarios."""
|
61 |
-
|
62 |
# Primary rerank model (used if mode-specific models are not defined)
|
63 |
rerank_model: Optional[RerankModel] = None
|
64 |
-
|
65 |
# Mode-specific rerank models
|
66 |
entity_rerank_model: Optional[RerankModel] = None
|
67 |
relation_rerank_model: Optional[RerankModel] = None
|
@@ -73,10 +96,10 @@ class MultiRerankModel(BaseModel):
|
|
73 |
documents: List[Dict[str, Any]],
|
74 |
mode: str = "default",
|
75 |
top_k: Optional[int] = None,
|
76 |
-
**kwargs
|
77 |
) -> List[Dict[str, Any]]:
|
78 |
"""Rerank using the appropriate model based on mode."""
|
79 |
-
|
80 |
# Select model based on mode
|
81 |
if mode == "entity" and self.entity_rerank_model:
|
82 |
model = self.entity_rerank_model
|
@@ -89,7 +112,7 @@ class MultiRerankModel(BaseModel):
|
|
89 |
else:
|
90 |
logger.warning(f"No rerank model available for mode: {mode}")
|
91 |
return documents
|
92 |
-
|
93 |
return await model.rerank(query, documents, top_k, **kwargs)
|
94 |
|
95 |
|
@@ -100,11 +123,11 @@ async def generic_rerank_api(
|
|
100 |
base_url: str,
|
101 |
api_key: str,
|
102 |
top_k: Optional[int] = None,
|
103 |
-
**kwargs
|
104 |
) -> List[Dict[str, Any]]:
|
105 |
"""
|
106 |
Generic rerank function that works with Jina/Cohere compatible APIs.
|
107 |
-
|
108 |
Args:
|
109 |
query: The search query
|
110 |
documents: List of documents to rerank
|
@@ -113,43 +136,35 @@ async def generic_rerank_api(
|
|
113 |
api_key: API authentication key
|
114 |
top_k: Number of top results to return
|
115 |
**kwargs: Additional API-specific parameters
|
116 |
-
|
117 |
Returns:
|
118 |
List of reranked documents with relevance scores
|
119 |
"""
|
120 |
if not api_key:
|
121 |
logger.warning("No API key provided for rerank service")
|
122 |
return documents
|
123 |
-
|
124 |
if not documents:
|
125 |
return documents
|
126 |
-
|
127 |
# Prepare documents for reranking - handle both text and dict formats
|
128 |
prepared_docs = []
|
129 |
for doc in documents:
|
130 |
if isinstance(doc, dict):
|
131 |
# Use 'content' field if available, otherwise use 'text' or convert to string
|
132 |
-
text = doc.get(
|
133 |
else:
|
134 |
text = str(doc)
|
135 |
prepared_docs.append(text)
|
136 |
-
|
137 |
# Prepare request
|
138 |
-
headers = {
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
data = {
|
144 |
-
"model": model,
|
145 |
-
"query": query,
|
146 |
-
"documents": prepared_docs,
|
147 |
-
**kwargs
|
148 |
-
}
|
149 |
-
|
150 |
if top_k is not None:
|
151 |
data["top_k"] = min(top_k, len(prepared_docs))
|
152 |
-
|
153 |
try:
|
154 |
async with aiohttp.ClientSession() as session:
|
155 |
async with session.post(base_url, headers=headers, json=data) as response:
|
@@ -157,9 +172,9 @@ async def generic_rerank_api(
|
|
157 |
error_text = await response.text()
|
158 |
logger.error(f"Rerank API error {response.status}: {error_text}")
|
159 |
return documents
|
160 |
-
|
161 |
result = await response.json()
|
162 |
-
|
163 |
# Extract reranked results
|
164 |
if "results" in result:
|
165 |
# Standard format: results contain index and relevance_score
|
@@ -170,13 +185,15 @@ async def generic_rerank_api(
|
|
170 |
if 0 <= doc_idx < len(documents):
|
171 |
reranked_doc = documents[doc_idx].copy()
|
172 |
if "relevance_score" in item:
|
173 |
-
reranked_doc["rerank_score"] = item["relevance_score"]
|
|
|
|
|
174 |
reranked_docs.append(reranked_doc)
|
175 |
return reranked_docs
|
176 |
else:
|
177 |
logger.warning("Unexpected rerank API response format")
|
178 |
return documents
|
179 |
-
|
180 |
except Exception as e:
|
181 |
logger.error(f"Error during reranking: {e}")
|
182 |
return documents
|
@@ -189,11 +206,11 @@ async def jina_rerank(
|
|
189 |
top_k: Optional[int] = None,
|
190 |
base_url: str = "https://api.jina.ai/v1/rerank",
|
191 |
api_key: Optional[str] = None,
|
192 |
-
**kwargs
|
193 |
) -> List[Dict[str, Any]]:
|
194 |
"""
|
195 |
Rerank documents using Jina AI API.
|
196 |
-
|
197 |
Args:
|
198 |
query: The search query
|
199 |
documents: List of documents to rerank
|
@@ -202,13 +219,13 @@ async def jina_rerank(
|
|
202 |
base_url: Jina API endpoint
|
203 |
api_key: Jina API key
|
204 |
**kwargs: Additional parameters
|
205 |
-
|
206 |
Returns:
|
207 |
List of reranked documents with relevance scores
|
208 |
"""
|
209 |
if api_key is None:
|
210 |
api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_API_KEY")
|
211 |
-
|
212 |
return await generic_rerank_api(
|
213 |
query=query,
|
214 |
documents=documents,
|
@@ -216,7 +233,7 @@ async def jina_rerank(
|
|
216 |
base_url=base_url,
|
217 |
api_key=api_key,
|
218 |
top_k=top_k,
|
219 |
-
**kwargs
|
220 |
)
|
221 |
|
222 |
|
@@ -227,11 +244,11 @@ async def cohere_rerank(
|
|
227 |
top_k: Optional[int] = None,
|
228 |
base_url: str = "https://api.cohere.ai/v1/rerank",
|
229 |
api_key: Optional[str] = None,
|
230 |
-
**kwargs
|
231 |
) -> List[Dict[str, Any]]:
|
232 |
"""
|
233 |
Rerank documents using Cohere API.
|
234 |
-
|
235 |
Args:
|
236 |
query: The search query
|
237 |
documents: List of documents to rerank
|
@@ -240,13 +257,13 @@ async def cohere_rerank(
|
|
240 |
base_url: Cohere API endpoint
|
241 |
api_key: Cohere API key
|
242 |
**kwargs: Additional parameters
|
243 |
-
|
244 |
Returns:
|
245 |
List of reranked documents with relevance scores
|
246 |
"""
|
247 |
if api_key is None:
|
248 |
api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_API_KEY")
|
249 |
-
|
250 |
return await generic_rerank_api(
|
251 |
query=query,
|
252 |
documents=documents,
|
@@ -254,7 +271,7 @@ async def cohere_rerank(
|
|
254 |
base_url=base_url,
|
255 |
api_key=api_key,
|
256 |
top_k=top_k,
|
257 |
-
**kwargs
|
258 |
)
|
259 |
|
260 |
|
@@ -266,7 +283,7 @@ async def custom_rerank(
|
|
266 |
base_url: str,
|
267 |
api_key: str,
|
268 |
top_k: Optional[int] = None,
|
269 |
-
**kwargs
|
270 |
) -> List[Dict[str, Any]]:
|
271 |
"""
|
272 |
Rerank documents using a custom API endpoint.
|
@@ -279,7 +296,7 @@ async def custom_rerank(
|
|
279 |
base_url=base_url,
|
280 |
api_key=api_key,
|
281 |
top_k=top_k,
|
282 |
-
**kwargs
|
283 |
)
|
284 |
|
285 |
|
@@ -293,15 +310,12 @@ if __name__ == "__main__":
|
|
293 |
{"content": "Tokyo is the capital of Japan."},
|
294 |
{"content": "London is the capital of England."},
|
295 |
]
|
296 |
-
|
297 |
query = "What is the capital of France?"
|
298 |
-
|
299 |
result = await jina_rerank(
|
300 |
-
query=query,
|
301 |
-
documents=docs,
|
302 |
-
top_k=2,
|
303 |
-
api_key="your-api-key-here"
|
304 |
)
|
305 |
print(result)
|
306 |
|
307 |
-
asyncio.run(main())
|
|
|
1 |
from __future__ import annotations
|
2 |
|
3 |
import os
|
|
|
4 |
import aiohttp
|
|
|
5 |
from typing import Callable, Any, List, Dict, Optional
|
6 |
from pydantic import BaseModel, Field
|
|
|
7 |
|
8 |
from .utils import logger
|
9 |
|
|
|
12 |
"""
|
13 |
Pydantic model class for defining a custom rerank model.
|
14 |
|
15 |
+
This class provides a convenient wrapper for rerank functions, allowing you to
|
16 |
+
encapsulate all rerank configurations (API keys, model settings, etc.) in one place.
|
17 |
+
|
18 |
Attributes:
|
19 |
rerank_func (Callable[[Any], List[Dict]]): A callable function that reranks documents.
|
20 |
The function should take query and documents as input and return reranked results.
|
21 |
kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function.
|
22 |
+
This should include all necessary configurations such as model name, API key, base_url, etc.
|
23 |
|
24 |
Example usage:
|
25 |
+
Rerank model example with Jina:
|
26 |
```python
|
27 |
rerank_model = RerankModel(
|
28 |
rerank_func=jina_rerank,
|
|
|
32 |
"base_url": "https://api.jina.ai/v1/rerank"
|
33 |
}
|
34 |
)
|
35 |
+
|
36 |
+
# Use in LightRAG
|
37 |
+
rag = LightRAG(
|
38 |
+
enable_rerank=True,
|
39 |
+
rerank_model_func=rerank_model.rerank,
|
40 |
+
# ... other configurations
|
41 |
+
)
|
42 |
+
```
|
43 |
+
|
44 |
+
Or define a custom function directly:
|
45 |
+
```python
|
46 |
+
async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
|
47 |
+
return await jina_rerank(
|
48 |
+
query=query,
|
49 |
+
documents=documents,
|
50 |
+
model="BAAI/bge-reranker-v2-m3",
|
51 |
+
api_key="your_api_key_here",
|
52 |
+
top_k=top_k or 10,
|
53 |
+
**kwargs
|
54 |
+
)
|
55 |
+
|
56 |
+
rag = LightRAG(
|
57 |
+
enable_rerank=True,
|
58 |
+
rerank_model_func=my_rerank_func,
|
59 |
+
# ... other configurations
|
60 |
+
)
|
61 |
```
|
62 |
"""
|
63 |
|
|
|
69 |
query: str,
|
70 |
documents: List[Dict[str, Any]],
|
71 |
top_k: Optional[int] = None,
|
72 |
+
**extra_kwargs,
|
73 |
) -> List[Dict[str, Any]]:
|
74 |
"""Rerank documents using the configured model function."""
|
75 |
# Merge extra kwargs with model kwargs
|
76 |
kwargs = {**self.kwargs, **extra_kwargs}
|
77 |
return await self.rerank_func(
|
78 |
+
query=query, documents=documents, top_k=top_k, **kwargs
|
|
|
|
|
|
|
79 |
)
|
80 |
|
81 |
|
82 |
class MultiRerankModel(BaseModel):
|
83 |
"""Multiple rerank models for different modes/scenarios."""
|
84 |
+
|
85 |
# Primary rerank model (used if mode-specific models are not defined)
|
86 |
rerank_model: Optional[RerankModel] = None
|
87 |
+
|
88 |
# Mode-specific rerank models
|
89 |
entity_rerank_model: Optional[RerankModel] = None
|
90 |
relation_rerank_model: Optional[RerankModel] = None
|
|
|
96 |
documents: List[Dict[str, Any]],
|
97 |
mode: str = "default",
|
98 |
top_k: Optional[int] = None,
|
99 |
+
**kwargs,
|
100 |
) -> List[Dict[str, Any]]:
|
101 |
"""Rerank using the appropriate model based on mode."""
|
102 |
+
|
103 |
# Select model based on mode
|
104 |
if mode == "entity" and self.entity_rerank_model:
|
105 |
model = self.entity_rerank_model
|
|
|
112 |
else:
|
113 |
logger.warning(f"No rerank model available for mode: {mode}")
|
114 |
return documents
|
115 |
+
|
116 |
return await model.rerank(query, documents, top_k, **kwargs)
|
117 |
|
118 |
|
|
|
123 |
base_url: str,
|
124 |
api_key: str,
|
125 |
top_k: Optional[int] = None,
|
126 |
+
**kwargs,
|
127 |
) -> List[Dict[str, Any]]:
|
128 |
"""
|
129 |
Generic rerank function that works with Jina/Cohere compatible APIs.
|
130 |
+
|
131 |
Args:
|
132 |
query: The search query
|
133 |
documents: List of documents to rerank
|
|
|
136 |
api_key: API authentication key
|
137 |
top_k: Number of top results to return
|
138 |
**kwargs: Additional API-specific parameters
|
139 |
+
|
140 |
Returns:
|
141 |
List of reranked documents with relevance scores
|
142 |
"""
|
143 |
if not api_key:
|
144 |
logger.warning("No API key provided for rerank service")
|
145 |
return documents
|
146 |
+
|
147 |
if not documents:
|
148 |
return documents
|
149 |
+
|
150 |
# Prepare documents for reranking - handle both text and dict formats
|
151 |
prepared_docs = []
|
152 |
for doc in documents:
|
153 |
if isinstance(doc, dict):
|
154 |
# Use 'content' field if available, otherwise use 'text' or convert to string
|
155 |
+
text = doc.get("content") or doc.get("text") or str(doc)
|
156 |
else:
|
157 |
text = str(doc)
|
158 |
prepared_docs.append(text)
|
159 |
+
|
160 |
# Prepare request
|
161 |
+
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
162 |
+
|
163 |
+
data = {"model": model, "query": query, "documents": prepared_docs, **kwargs}
|
164 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
if top_k is not None:
|
166 |
data["top_k"] = min(top_k, len(prepared_docs))
|
167 |
+
|
168 |
try:
|
169 |
async with aiohttp.ClientSession() as session:
|
170 |
async with session.post(base_url, headers=headers, json=data) as response:
|
|
|
172 |
error_text = await response.text()
|
173 |
logger.error(f"Rerank API error {response.status}: {error_text}")
|
174 |
return documents
|
175 |
+
|
176 |
result = await response.json()
|
177 |
+
|
178 |
# Extract reranked results
|
179 |
if "results" in result:
|
180 |
# Standard format: results contain index and relevance_score
|
|
|
185 |
if 0 <= doc_idx < len(documents):
|
186 |
reranked_doc = documents[doc_idx].copy()
|
187 |
if "relevance_score" in item:
|
188 |
+
reranked_doc["rerank_score"] = item[
|
189 |
+
"relevance_score"
|
190 |
+
]
|
191 |
reranked_docs.append(reranked_doc)
|
192 |
return reranked_docs
|
193 |
else:
|
194 |
logger.warning("Unexpected rerank API response format")
|
195 |
return documents
|
196 |
+
|
197 |
except Exception as e:
|
198 |
logger.error(f"Error during reranking: {e}")
|
199 |
return documents
|
|
|
206 |
top_k: Optional[int] = None,
|
207 |
base_url: str = "https://api.jina.ai/v1/rerank",
|
208 |
api_key: Optional[str] = None,
|
209 |
+
**kwargs,
|
210 |
) -> List[Dict[str, Any]]:
|
211 |
"""
|
212 |
Rerank documents using Jina AI API.
|
213 |
+
|
214 |
Args:
|
215 |
query: The search query
|
216 |
documents: List of documents to rerank
|
|
|
219 |
base_url: Jina API endpoint
|
220 |
api_key: Jina API key
|
221 |
**kwargs: Additional parameters
|
222 |
+
|
223 |
Returns:
|
224 |
List of reranked documents with relevance scores
|
225 |
"""
|
226 |
if api_key is None:
|
227 |
api_key = os.getenv("JINA_API_KEY") or os.getenv("RERANK_API_KEY")
|
228 |
+
|
229 |
return await generic_rerank_api(
|
230 |
query=query,
|
231 |
documents=documents,
|
|
|
233 |
base_url=base_url,
|
234 |
api_key=api_key,
|
235 |
top_k=top_k,
|
236 |
+
**kwargs,
|
237 |
)
|
238 |
|
239 |
|
|
|
244 |
top_k: Optional[int] = None,
|
245 |
base_url: str = "https://api.cohere.ai/v1/rerank",
|
246 |
api_key: Optional[str] = None,
|
247 |
+
**kwargs,
|
248 |
) -> List[Dict[str, Any]]:
|
249 |
"""
|
250 |
Rerank documents using Cohere API.
|
251 |
+
|
252 |
Args:
|
253 |
query: The search query
|
254 |
documents: List of documents to rerank
|
|
|
257 |
base_url: Cohere API endpoint
|
258 |
api_key: Cohere API key
|
259 |
**kwargs: Additional parameters
|
260 |
+
|
261 |
Returns:
|
262 |
List of reranked documents with relevance scores
|
263 |
"""
|
264 |
if api_key is None:
|
265 |
api_key = os.getenv("COHERE_API_KEY") or os.getenv("RERANK_API_KEY")
|
266 |
+
|
267 |
return await generic_rerank_api(
|
268 |
query=query,
|
269 |
documents=documents,
|
|
|
271 |
base_url=base_url,
|
272 |
api_key=api_key,
|
273 |
top_k=top_k,
|
274 |
+
**kwargs,
|
275 |
)
|
276 |
|
277 |
|
|
|
283 |
base_url: str,
|
284 |
api_key: str,
|
285 |
top_k: Optional[int] = None,
|
286 |
+
**kwargs,
|
287 |
) -> List[Dict[str, Any]]:
|
288 |
"""
|
289 |
Rerank documents using a custom API endpoint.
|
|
|
296 |
base_url=base_url,
|
297 |
api_key=api_key,
|
298 |
top_k=top_k,
|
299 |
+
**kwargs,
|
300 |
)
|
301 |
|
302 |
|
|
|
310 |
{"content": "Tokyo is the capital of Japan."},
|
311 |
{"content": "London is the capital of England."},
|
312 |
]
|
313 |
+
|
314 |
query = "What is the capital of France?"
|
315 |
+
|
316 |
result = await jina_rerank(
|
317 |
+
query=query, documents=docs, top_k=2, api_key="your-api-key-here"
|
|
|
|
|
|
|
318 |
)
|
319 |
print(result)
|
320 |
|
321 |
+
asyncio.run(main())
|