Merge pull request #1815 from danielaskdd/rerank-top-n
Browse files- examples/rerank_example.py +3 -3
- lightrag/api/lightrag_server.py +2 -2
- lightrag/operate.py +7 -7
- lightrag/rerank.py +19 -19
examples/rerank_example.py
CHANGED
@@ -57,7 +57,7 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
|
|
57 |
)
|
58 |
|
59 |
|
60 |
-
async def my_rerank_func(query: str, documents: list,
|
61 |
"""Custom rerank function with all settings included"""
|
62 |
return await custom_rerank(
|
63 |
query=query,
|
@@ -65,7 +65,7 @@ async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwarg
|
|
65 |
model="BAAI/bge-reranker-v2-m3",
|
66 |
base_url="https://api.your-rerank-provider.com/v1/rerank",
|
67 |
api_key="your_rerank_api_key_here",
|
68 |
-
|
69 |
**kwargs,
|
70 |
)
|
71 |
|
@@ -217,7 +217,7 @@ async def test_direct_rerank():
|
|
217 |
model="BAAI/bge-reranker-v2-m3",
|
218 |
base_url="https://api.your-rerank-provider.com/v1/rerank",
|
219 |
api_key="your_rerank_api_key_here",
|
220 |
-
|
221 |
)
|
222 |
|
223 |
print("\n✅ Rerank Results:")
|
|
|
57 |
)
|
58 |
|
59 |
|
60 |
+
async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
|
61 |
"""Custom rerank function with all settings included"""
|
62 |
return await custom_rerank(
|
63 |
query=query,
|
|
|
65 |
model="BAAI/bge-reranker-v2-m3",
|
66 |
base_url="https://api.your-rerank-provider.com/v1/rerank",
|
67 |
api_key="your_rerank_api_key_here",
|
68 |
+
top_n=top_n or 10,
|
69 |
**kwargs,
|
70 |
)
|
71 |
|
|
|
217 |
model="BAAI/bge-reranker-v2-m3",
|
218 |
base_url="https://api.your-rerank-provider.com/v1/rerank",
|
219 |
api_key="your_rerank_api_key_here",
|
220 |
+
top_n=3,
|
221 |
)
|
222 |
|
223 |
print("\n✅ Rerank Results:")
|
lightrag/api/lightrag_server.py
CHANGED
@@ -298,7 +298,7 @@ def create_app(args):
|
|
298 |
from lightrag.rerank import custom_rerank
|
299 |
|
300 |
async def server_rerank_func(
|
301 |
-
query: str, documents: list,
|
302 |
):
|
303 |
"""Server rerank function with configuration from environment variables"""
|
304 |
return await custom_rerank(
|
@@ -307,7 +307,7 @@ def create_app(args):
|
|
307 |
model=args.rerank_model,
|
308 |
base_url=args.rerank_binding_host,
|
309 |
api_key=args.rerank_binding_api_key,
|
310 |
-
|
311 |
**kwargs,
|
312 |
)
|
313 |
|
|
|
298 |
from lightrag.rerank import custom_rerank
|
299 |
|
300 |
async def server_rerank_func(
|
301 |
+
query: str, documents: list, top_n: int = None, **kwargs
|
302 |
):
|
303 |
"""Server rerank function with configuration from environment variables"""
|
304 |
return await custom_rerank(
|
|
|
307 |
model=args.rerank_model,
|
308 |
base_url=args.rerank_binding_host,
|
309 |
api_key=args.rerank_binding_api_key,
|
310 |
+
top_n=top_n,
|
311 |
**kwargs,
|
312 |
)
|
313 |
|
lightrag/operate.py
CHANGED
@@ -3165,7 +3165,7 @@ async def apply_rerank_if_enabled(
|
|
3165 |
retrieved_docs: list[dict],
|
3166 |
global_config: dict,
|
3167 |
enable_rerank: bool = True,
|
3168 |
-
|
3169 |
) -> list[dict]:
|
3170 |
"""
|
3171 |
Apply reranking to retrieved documents if rerank is enabled.
|
@@ -3175,7 +3175,7 @@ async def apply_rerank_if_enabled(
|
|
3175 |
retrieved_docs: List of retrieved documents
|
3176 |
global_config: Global configuration containing rerank settings
|
3177 |
enable_rerank: Whether to enable reranking from query parameter
|
3178 |
-
|
3179 |
|
3180 |
Returns:
|
3181 |
Reranked documents if rerank is enabled, otherwise original documents
|
@@ -3192,18 +3192,18 @@ async def apply_rerank_if_enabled(
|
|
3192 |
|
3193 |
try:
|
3194 |
logger.debug(
|
3195 |
-
f"Applying rerank to {len(retrieved_docs)} documents, returning top {
|
3196 |
)
|
3197 |
|
3198 |
# Apply reranking - let rerank_model_func handle top_k internally
|
3199 |
reranked_docs = await rerank_func(
|
3200 |
query=query,
|
3201 |
documents=retrieved_docs,
|
3202 |
-
|
3203 |
)
|
3204 |
if reranked_docs and len(reranked_docs) > 0:
|
3205 |
-
if len(reranked_docs) >
|
3206 |
-
reranked_docs = reranked_docs[:
|
3207 |
logger.info(
|
3208 |
f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}"
|
3209 |
)
|
@@ -3263,7 +3263,7 @@ async def process_chunks_unified(
|
|
3263 |
retrieved_docs=unique_chunks,
|
3264 |
global_config=global_config,
|
3265 |
enable_rerank=query_param.enable_rerank,
|
3266 |
-
|
3267 |
)
|
3268 |
logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})")
|
3269 |
|
|
|
3165 |
retrieved_docs: list[dict],
|
3166 |
global_config: dict,
|
3167 |
enable_rerank: bool = True,
|
3168 |
+
top_n: int = None,
|
3169 |
) -> list[dict]:
|
3170 |
"""
|
3171 |
Apply reranking to retrieved documents if rerank is enabled.
|
|
|
3175 |
retrieved_docs: List of retrieved documents
|
3176 |
global_config: Global configuration containing rerank settings
|
3177 |
enable_rerank: Whether to enable reranking from query parameter
|
3178 |
+
top_n: Number of top documents to return after reranking
|
3179 |
|
3180 |
Returns:
|
3181 |
Reranked documents if rerank is enabled, otherwise original documents
|
|
|
3192 |
|
3193 |
try:
|
3194 |
logger.debug(
|
3195 |
+
f"Applying rerank to {len(retrieved_docs)} documents, returning top {top_n}"
|
3196 |
)
|
3197 |
|
3198 |
# Apply reranking - let rerank_model_func handle top_k internally
|
3199 |
reranked_docs = await rerank_func(
|
3200 |
query=query,
|
3201 |
documents=retrieved_docs,
|
3202 |
+
top_n=top_n,
|
3203 |
)
|
3204 |
if reranked_docs and len(reranked_docs) > 0:
|
3205 |
+
if len(reranked_docs) > top_n:
|
3206 |
+
reranked_docs = reranked_docs[:top_n]
|
3207 |
logger.info(
|
3208 |
f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}"
|
3209 |
)
|
|
|
3263 |
retrieved_docs=unique_chunks,
|
3264 |
global_config=global_config,
|
3265 |
enable_rerank=query_param.enable_rerank,
|
3266 |
+
top_n=rerank_top_k,
|
3267 |
)
|
3268 |
logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})")
|
3269 |
|
lightrag/rerank.py
CHANGED
@@ -41,13 +41,13 @@ class RerankModel(BaseModel):
|
|
41 |
|
42 |
Or define a custom function directly:
|
43 |
```python
|
44 |
-
async def my_rerank_func(query: str, documents: list,
|
45 |
return await jina_rerank(
|
46 |
query=query,
|
47 |
documents=documents,
|
48 |
model="BAAI/bge-reranker-v2-m3",
|
49 |
api_key="your_api_key_here",
|
50 |
-
|
51 |
**kwargs
|
52 |
)
|
53 |
|
@@ -71,14 +71,14 @@ class RerankModel(BaseModel):
|
|
71 |
self,
|
72 |
query: str,
|
73 |
documents: List[Dict[str, Any]],
|
74 |
-
|
75 |
**extra_kwargs,
|
76 |
) -> List[Dict[str, Any]]:
|
77 |
"""Rerank documents using the configured model function."""
|
78 |
# Merge extra kwargs with model kwargs
|
79 |
kwargs = {**self.kwargs, **extra_kwargs}
|
80 |
return await self.rerank_func(
|
81 |
-
query=query, documents=documents,
|
82 |
)
|
83 |
|
84 |
|
@@ -98,7 +98,7 @@ class MultiRerankModel(BaseModel):
|
|
98 |
query: str,
|
99 |
documents: List[Dict[str, Any]],
|
100 |
mode: str = "default",
|
101 |
-
|
102 |
**kwargs,
|
103 |
) -> List[Dict[str, Any]]:
|
104 |
"""Rerank using the appropriate model based on mode."""
|
@@ -116,7 +116,7 @@ class MultiRerankModel(BaseModel):
|
|
116 |
logger.warning(f"No rerank model available for mode: {mode}")
|
117 |
return documents
|
118 |
|
119 |
-
return await model.rerank(query, documents,
|
120 |
|
121 |
|
122 |
async def generic_rerank_api(
|
@@ -125,7 +125,7 @@ async def generic_rerank_api(
|
|
125 |
model: str,
|
126 |
base_url: str,
|
127 |
api_key: str,
|
128 |
-
|
129 |
**kwargs,
|
130 |
) -> List[Dict[str, Any]]:
|
131 |
"""
|
@@ -137,7 +137,7 @@ async def generic_rerank_api(
|
|
137 |
model: Model identifier
|
138 |
base_url: API endpoint URL
|
139 |
api_key: API authentication key
|
140 |
-
|
141 |
**kwargs: Additional API-specific parameters
|
142 |
|
143 |
Returns:
|
@@ -165,8 +165,8 @@ async def generic_rerank_api(
|
|
165 |
|
166 |
data = {"model": model, "query": query, "documents": prepared_docs, **kwargs}
|
167 |
|
168 |
-
if
|
169 |
-
data["
|
170 |
|
171 |
try:
|
172 |
async with aiohttp.ClientSession() as session:
|
@@ -206,7 +206,7 @@ async def jina_rerank(
|
|
206 |
query: str,
|
207 |
documents: List[Dict[str, Any]],
|
208 |
model: str = "BAAI/bge-reranker-v2-m3",
|
209 |
-
|
210 |
base_url: str = "https://api.jina.ai/v1/rerank",
|
211 |
api_key: Optional[str] = None,
|
212 |
**kwargs,
|
@@ -218,7 +218,7 @@ async def jina_rerank(
|
|
218 |
query: The search query
|
219 |
documents: List of documents to rerank
|
220 |
model: Jina rerank model name
|
221 |
-
|
222 |
base_url: Jina API endpoint
|
223 |
api_key: Jina API key
|
224 |
**kwargs: Additional parameters
|
@@ -235,7 +235,7 @@ async def jina_rerank(
|
|
235 |
model=model,
|
236 |
base_url=base_url,
|
237 |
api_key=api_key,
|
238 |
-
|
239 |
**kwargs,
|
240 |
)
|
241 |
|
@@ -244,7 +244,7 @@ async def cohere_rerank(
|
|
244 |
query: str,
|
245 |
documents: List[Dict[str, Any]],
|
246 |
model: str = "rerank-english-v2.0",
|
247 |
-
|
248 |
base_url: str = "https://api.cohere.ai/v1/rerank",
|
249 |
api_key: Optional[str] = None,
|
250 |
**kwargs,
|
@@ -256,7 +256,7 @@ async def cohere_rerank(
|
|
256 |
query: The search query
|
257 |
documents: List of documents to rerank
|
258 |
model: Cohere rerank model name
|
259 |
-
|
260 |
base_url: Cohere API endpoint
|
261 |
api_key: Cohere API key
|
262 |
**kwargs: Additional parameters
|
@@ -273,7 +273,7 @@ async def cohere_rerank(
|
|
273 |
model=model,
|
274 |
base_url=base_url,
|
275 |
api_key=api_key,
|
276 |
-
|
277 |
**kwargs,
|
278 |
)
|
279 |
|
@@ -285,7 +285,7 @@ async def custom_rerank(
|
|
285 |
model: str,
|
286 |
base_url: str,
|
287 |
api_key: str,
|
288 |
-
|
289 |
**kwargs,
|
290 |
) -> List[Dict[str, Any]]:
|
291 |
"""
|
@@ -298,7 +298,7 @@ async def custom_rerank(
|
|
298 |
model=model,
|
299 |
base_url=base_url,
|
300 |
api_key=api_key,
|
301 |
-
|
302 |
**kwargs,
|
303 |
)
|
304 |
|
@@ -317,7 +317,7 @@ if __name__ == "__main__":
|
|
317 |
query = "What is the capital of France?"
|
318 |
|
319 |
result = await jina_rerank(
|
320 |
-
query=query, documents=docs,
|
321 |
)
|
322 |
print(result)
|
323 |
|
|
|
41 |
|
42 |
Or define a custom function directly:
|
43 |
```python
|
44 |
+
async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
|
45 |
return await jina_rerank(
|
46 |
query=query,
|
47 |
documents=documents,
|
48 |
model="BAAI/bge-reranker-v2-m3",
|
49 |
api_key="your_api_key_here",
|
50 |
+
top_n=top_n or 10,
|
51 |
**kwargs
|
52 |
)
|
53 |
|
|
|
71 |
self,
|
72 |
query: str,
|
73 |
documents: List[Dict[str, Any]],
|
74 |
+
top_n: Optional[int] = None,
|
75 |
**extra_kwargs,
|
76 |
) -> List[Dict[str, Any]]:
|
77 |
"""Rerank documents using the configured model function."""
|
78 |
# Merge extra kwargs with model kwargs
|
79 |
kwargs = {**self.kwargs, **extra_kwargs}
|
80 |
return await self.rerank_func(
|
81 |
+
query=query, documents=documents, top_n=top_n, **kwargs
|
82 |
)
|
83 |
|
84 |
|
|
|
98 |
query: str,
|
99 |
documents: List[Dict[str, Any]],
|
100 |
mode: str = "default",
|
101 |
+
top_n: Optional[int] = None,
|
102 |
**kwargs,
|
103 |
) -> List[Dict[str, Any]]:
|
104 |
"""Rerank using the appropriate model based on mode."""
|
|
|
116 |
logger.warning(f"No rerank model available for mode: {mode}")
|
117 |
return documents
|
118 |
|
119 |
+
return await model.rerank(query, documents, top_n, **kwargs)
|
120 |
|
121 |
|
122 |
async def generic_rerank_api(
|
|
|
125 |
model: str,
|
126 |
base_url: str,
|
127 |
api_key: str,
|
128 |
+
top_n: Optional[int] = None,
|
129 |
**kwargs,
|
130 |
) -> List[Dict[str, Any]]:
|
131 |
"""
|
|
|
137 |
model: Model identifier
|
138 |
base_url: API endpoint URL
|
139 |
api_key: API authentication key
|
140 |
+
top_n: Number of top results to return
|
141 |
**kwargs: Additional API-specific parameters
|
142 |
|
143 |
Returns:
|
|
|
165 |
|
166 |
data = {"model": model, "query": query, "documents": prepared_docs, **kwargs}
|
167 |
|
168 |
+
if top_n is not None:
|
169 |
+
data["top_n"] = min(top_n, len(prepared_docs))
|
170 |
|
171 |
try:
|
172 |
async with aiohttp.ClientSession() as session:
|
|
|
206 |
query: str,
|
207 |
documents: List[Dict[str, Any]],
|
208 |
model: str = "BAAI/bge-reranker-v2-m3",
|
209 |
+
top_n: Optional[int] = None,
|
210 |
base_url: str = "https://api.jina.ai/v1/rerank",
|
211 |
api_key: Optional[str] = None,
|
212 |
**kwargs,
|
|
|
218 |
query: The search query
|
219 |
documents: List of documents to rerank
|
220 |
model: Jina rerank model name
|
221 |
+
top_n: Number of top results to return
|
222 |
base_url: Jina API endpoint
|
223 |
api_key: Jina API key
|
224 |
**kwargs: Additional parameters
|
|
|
235 |
model=model,
|
236 |
base_url=base_url,
|
237 |
api_key=api_key,
|
238 |
+
top_n=top_n,
|
239 |
**kwargs,
|
240 |
)
|
241 |
|
|
|
244 |
query: str,
|
245 |
documents: List[Dict[str, Any]],
|
246 |
model: str = "rerank-english-v2.0",
|
247 |
+
top_n: Optional[int] = None,
|
248 |
base_url: str = "https://api.cohere.ai/v1/rerank",
|
249 |
api_key: Optional[str] = None,
|
250 |
**kwargs,
|
|
|
256 |
query: The search query
|
257 |
documents: List of documents to rerank
|
258 |
model: Cohere rerank model name
|
259 |
+
top_n: Number of top results to return
|
260 |
base_url: Cohere API endpoint
|
261 |
api_key: Cohere API key
|
262 |
**kwargs: Additional parameters
|
|
|
273 |
model=model,
|
274 |
base_url=base_url,
|
275 |
api_key=api_key,
|
276 |
+
top_n=top_n,
|
277 |
**kwargs,
|
278 |
)
|
279 |
|
|
|
285 |
model: str,
|
286 |
base_url: str,
|
287 |
api_key: str,
|
288 |
+
top_n: Optional[int] = None,
|
289 |
**kwargs,
|
290 |
) -> List[Dict[str, Any]]:
|
291 |
"""
|
|
|
298 |
model=model,
|
299 |
base_url=base_url,
|
300 |
api_key=api_key,
|
301 |
+
top_n=top_n,
|
302 |
**kwargs,
|
303 |
)
|
304 |
|
|
|
317 |
query = "What is the capital of France?"
|
318 |
|
319 |
result = await jina_rerank(
|
320 |
+
query=query, documents=docs, top_n=2, api_key="your-api-key-here"
|
321 |
)
|
322 |
print(result)
|
323 |
|