Daniel.y committed on
Commit
940fee0
·
unverified ·
2 Parent(s): d4a8eb7 dcf00c8

Merge pull request #1815 from danielaskdd/rerank-top-n

Browse files
examples/rerank_example.py CHANGED
@@ -57,7 +57,7 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
57
  )
58
 
59
 
60
- async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
61
  """Custom rerank function with all settings included"""
62
  return await custom_rerank(
63
  query=query,
@@ -65,7 +65,7 @@ async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwarg
65
  model="BAAI/bge-reranker-v2-m3",
66
  base_url="https://api.your-rerank-provider.com/v1/rerank",
67
  api_key="your_rerank_api_key_here",
68
- top_k=top_k or 10, # Default top_k if not provided
69
  **kwargs,
70
  )
71
 
@@ -217,7 +217,7 @@ async def test_direct_rerank():
217
  model="BAAI/bge-reranker-v2-m3",
218
  base_url="https://api.your-rerank-provider.com/v1/rerank",
219
  api_key="your_rerank_api_key_here",
220
- top_k=3,
221
  )
222
 
223
  print("\n✅ Rerank Results:")
 
57
  )
58
 
59
 
60
+ async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
61
  """Custom rerank function with all settings included"""
62
  return await custom_rerank(
63
  query=query,
 
65
  model="BAAI/bge-reranker-v2-m3",
66
  base_url="https://api.your-rerank-provider.com/v1/rerank",
67
  api_key="your_rerank_api_key_here",
68
+ top_n=top_n or 10,
69
  **kwargs,
70
  )
71
 
 
217
  model="BAAI/bge-reranker-v2-m3",
218
  base_url="https://api.your-rerank-provider.com/v1/rerank",
219
  api_key="your_rerank_api_key_here",
220
+ top_n=3,
221
  )
222
 
223
  print("\n✅ Rerank Results:")
lightrag/api/lightrag_server.py CHANGED
@@ -298,7 +298,7 @@ def create_app(args):
298
  from lightrag.rerank import custom_rerank
299
 
300
  async def server_rerank_func(
301
- query: str, documents: list, top_k: int = None, **kwargs
302
  ):
303
  """Server rerank function with configuration from environment variables"""
304
  return await custom_rerank(
@@ -307,7 +307,7 @@ def create_app(args):
307
  model=args.rerank_model,
308
  base_url=args.rerank_binding_host,
309
  api_key=args.rerank_binding_api_key,
310
- top_k=top_k,
311
  **kwargs,
312
  )
313
 
 
298
  from lightrag.rerank import custom_rerank
299
 
300
  async def server_rerank_func(
301
+ query: str, documents: list, top_n: int = None, **kwargs
302
  ):
303
  """Server rerank function with configuration from environment variables"""
304
  return await custom_rerank(
 
307
  model=args.rerank_model,
308
  base_url=args.rerank_binding_host,
309
  api_key=args.rerank_binding_api_key,
310
+ top_n=top_n,
311
  **kwargs,
312
  )
313
 
lightrag/operate.py CHANGED
@@ -3165,7 +3165,7 @@ async def apply_rerank_if_enabled(
3165
  retrieved_docs: list[dict],
3166
  global_config: dict,
3167
  enable_rerank: bool = True,
3168
- top_k: int = None,
3169
  ) -> list[dict]:
3170
  """
3171
  Apply reranking to retrieved documents if rerank is enabled.
@@ -3175,7 +3175,7 @@ async def apply_rerank_if_enabled(
3175
  retrieved_docs: List of retrieved documents
3176
  global_config: Global configuration containing rerank settings
3177
  enable_rerank: Whether to enable reranking from query parameter
3178
- top_k: Number of top documents to return after reranking
3179
 
3180
  Returns:
3181
  Reranked documents if rerank is enabled, otherwise original documents
@@ -3192,18 +3192,18 @@ async def apply_rerank_if_enabled(
3192
 
3193
  try:
3194
  logger.debug(
3195
- f"Applying rerank to {len(retrieved_docs)} documents, returning top {top_k}"
3196
  )
3197
 
3198
  # Apply reranking - let rerank_model_func handle top_k internally
3199
  reranked_docs = await rerank_func(
3200
  query=query,
3201
  documents=retrieved_docs,
3202
- top_k=top_k,
3203
  )
3204
  if reranked_docs and len(reranked_docs) > 0:
3205
- if len(reranked_docs) > top_k:
3206
- reranked_docs = reranked_docs[:top_k]
3207
  logger.info(
3208
  f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}"
3209
  )
@@ -3263,7 +3263,7 @@ async def process_chunks_unified(
3263
  retrieved_docs=unique_chunks,
3264
  global_config=global_config,
3265
  enable_rerank=query_param.enable_rerank,
3266
- top_k=rerank_top_k,
3267
  )
3268
  logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})")
3269
 
 
3165
  retrieved_docs: list[dict],
3166
  global_config: dict,
3167
  enable_rerank: bool = True,
3168
+ top_n: int = None,
3169
  ) -> list[dict]:
3170
  """
3171
  Apply reranking to retrieved documents if rerank is enabled.
 
3175
  retrieved_docs: List of retrieved documents
3176
  global_config: Global configuration containing rerank settings
3177
  enable_rerank: Whether to enable reranking from query parameter
3178
+ top_n: Number of top documents to return after reranking
3179
 
3180
  Returns:
3181
  Reranked documents if rerank is enabled, otherwise original documents
 
3192
 
3193
  try:
3194
  logger.debug(
3195
+ f"Applying rerank to {len(retrieved_docs)} documents, returning top {top_n}"
3196
  )
3197
 
3198
  # Apply reranking - let rerank_model_func handle top_k internally
3199
  reranked_docs = await rerank_func(
3200
  query=query,
3201
  documents=retrieved_docs,
3202
+ top_n=top_n,
3203
  )
3204
  if reranked_docs and len(reranked_docs) > 0:
3205
+ if len(reranked_docs) > top_n:
3206
+ reranked_docs = reranked_docs[:top_n]
3207
  logger.info(
3208
  f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}"
3209
  )
 
3263
  retrieved_docs=unique_chunks,
3264
  global_config=global_config,
3265
  enable_rerank=query_param.enable_rerank,
3266
+ top_n=rerank_top_k,
3267
  )
3268
  logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})")
3269
 
lightrag/rerank.py CHANGED
@@ -41,13 +41,13 @@ class RerankModel(BaseModel):
41
 
42
  Or define a custom function directly:
43
  ```python
44
- async def my_rerank_func(query: str, documents: list, top_k: int = None, **kwargs):
45
  return await jina_rerank(
46
  query=query,
47
  documents=documents,
48
  model="BAAI/bge-reranker-v2-m3",
49
  api_key="your_api_key_here",
50
- top_k=top_k or 10,
51
  **kwargs
52
  )
53
 
@@ -71,14 +71,14 @@ class RerankModel(BaseModel):
71
  self,
72
  query: str,
73
  documents: List[Dict[str, Any]],
74
- top_k: Optional[int] = None,
75
  **extra_kwargs,
76
  ) -> List[Dict[str, Any]]:
77
  """Rerank documents using the configured model function."""
78
  # Merge extra kwargs with model kwargs
79
  kwargs = {**self.kwargs, **extra_kwargs}
80
  return await self.rerank_func(
81
- query=query, documents=documents, top_k=top_k, **kwargs
82
  )
83
 
84
 
@@ -98,7 +98,7 @@ class MultiRerankModel(BaseModel):
98
  query: str,
99
  documents: List[Dict[str, Any]],
100
  mode: str = "default",
101
- top_k: Optional[int] = None,
102
  **kwargs,
103
  ) -> List[Dict[str, Any]]:
104
  """Rerank using the appropriate model based on mode."""
@@ -116,7 +116,7 @@ class MultiRerankModel(BaseModel):
116
  logger.warning(f"No rerank model available for mode: {mode}")
117
  return documents
118
 
119
- return await model.rerank(query, documents, top_k, **kwargs)
120
 
121
 
122
  async def generic_rerank_api(
@@ -125,7 +125,7 @@ async def generic_rerank_api(
125
  model: str,
126
  base_url: str,
127
  api_key: str,
128
- top_k: Optional[int] = None,
129
  **kwargs,
130
  ) -> List[Dict[str, Any]]:
131
  """
@@ -137,7 +137,7 @@ async def generic_rerank_api(
137
  model: Model identifier
138
  base_url: API endpoint URL
139
  api_key: API authentication key
140
- top_k: Number of top results to return
141
  **kwargs: Additional API-specific parameters
142
 
143
  Returns:
@@ -165,8 +165,8 @@ async def generic_rerank_api(
165
 
166
  data = {"model": model, "query": query, "documents": prepared_docs, **kwargs}
167
 
168
- if top_k is not None:
169
- data["top_k"] = min(top_k, len(prepared_docs))
170
 
171
  try:
172
  async with aiohttp.ClientSession() as session:
@@ -206,7 +206,7 @@ async def jina_rerank(
206
  query: str,
207
  documents: List[Dict[str, Any]],
208
  model: str = "BAAI/bge-reranker-v2-m3",
209
- top_k: Optional[int] = None,
210
  base_url: str = "https://api.jina.ai/v1/rerank",
211
  api_key: Optional[str] = None,
212
  **kwargs,
@@ -218,7 +218,7 @@ async def jina_rerank(
218
  query: The search query
219
  documents: List of documents to rerank
220
  model: Jina rerank model name
221
- top_k: Number of top results to return
222
  base_url: Jina API endpoint
223
  api_key: Jina API key
224
  **kwargs: Additional parameters
@@ -235,7 +235,7 @@ async def jina_rerank(
235
  model=model,
236
  base_url=base_url,
237
  api_key=api_key,
238
- top_k=top_k,
239
  **kwargs,
240
  )
241
 
@@ -244,7 +244,7 @@ async def cohere_rerank(
244
  query: str,
245
  documents: List[Dict[str, Any]],
246
  model: str = "rerank-english-v2.0",
247
- top_k: Optional[int] = None,
248
  base_url: str = "https://api.cohere.ai/v1/rerank",
249
  api_key: Optional[str] = None,
250
  **kwargs,
@@ -256,7 +256,7 @@ async def cohere_rerank(
256
  query: The search query
257
  documents: List of documents to rerank
258
  model: Cohere rerank model name
259
- top_k: Number of top results to return
260
  base_url: Cohere API endpoint
261
  api_key: Cohere API key
262
  **kwargs: Additional parameters
@@ -273,7 +273,7 @@ async def cohere_rerank(
273
  model=model,
274
  base_url=base_url,
275
  api_key=api_key,
276
- top_k=top_k,
277
  **kwargs,
278
  )
279
 
@@ -285,7 +285,7 @@ async def custom_rerank(
285
  model: str,
286
  base_url: str,
287
  api_key: str,
288
- top_k: Optional[int] = None,
289
  **kwargs,
290
  ) -> List[Dict[str, Any]]:
291
  """
@@ -298,7 +298,7 @@ async def custom_rerank(
298
  model=model,
299
  base_url=base_url,
300
  api_key=api_key,
301
- top_k=top_k,
302
  **kwargs,
303
  )
304
 
@@ -317,7 +317,7 @@ if __name__ == "__main__":
317
  query = "What is the capital of France?"
318
 
319
  result = await jina_rerank(
320
- query=query, documents=docs, top_k=2, api_key="your-api-key-here"
321
  )
322
  print(result)
323
 
 
41
 
42
  Or define a custom function directly:
43
  ```python
44
+ async def my_rerank_func(query: str, documents: list, top_n: int = None, **kwargs):
45
  return await jina_rerank(
46
  query=query,
47
  documents=documents,
48
  model="BAAI/bge-reranker-v2-m3",
49
  api_key="your_api_key_here",
50
+ top_n=top_n or 10,
51
  **kwargs
52
  )
53
 
 
71
  self,
72
  query: str,
73
  documents: List[Dict[str, Any]],
74
+ top_n: Optional[int] = None,
75
  **extra_kwargs,
76
  ) -> List[Dict[str, Any]]:
77
  """Rerank documents using the configured model function."""
78
  # Merge extra kwargs with model kwargs
79
  kwargs = {**self.kwargs, **extra_kwargs}
80
  return await self.rerank_func(
81
+ query=query, documents=documents, top_n=top_n, **kwargs
82
  )
83
 
84
 
 
98
  query: str,
99
  documents: List[Dict[str, Any]],
100
  mode: str = "default",
101
+ top_n: Optional[int] = None,
102
  **kwargs,
103
  ) -> List[Dict[str, Any]]:
104
  """Rerank using the appropriate model based on mode."""
 
116
  logger.warning(f"No rerank model available for mode: {mode}")
117
  return documents
118
 
119
+ return await model.rerank(query, documents, top_n, **kwargs)
120
 
121
 
122
  async def generic_rerank_api(
 
125
  model: str,
126
  base_url: str,
127
  api_key: str,
128
+ top_n: Optional[int] = None,
129
  **kwargs,
130
  ) -> List[Dict[str, Any]]:
131
  """
 
137
  model: Model identifier
138
  base_url: API endpoint URL
139
  api_key: API authentication key
140
+ top_n: Number of top results to return
141
  **kwargs: Additional API-specific parameters
142
 
143
  Returns:
 
165
 
166
  data = {"model": model, "query": query, "documents": prepared_docs, **kwargs}
167
 
168
+ if top_n is not None:
169
+ data["top_n"] = min(top_n, len(prepared_docs))
170
 
171
  try:
172
  async with aiohttp.ClientSession() as session:
 
206
  query: str,
207
  documents: List[Dict[str, Any]],
208
  model: str = "BAAI/bge-reranker-v2-m3",
209
+ top_n: Optional[int] = None,
210
  base_url: str = "https://api.jina.ai/v1/rerank",
211
  api_key: Optional[str] = None,
212
  **kwargs,
 
218
  query: The search query
219
  documents: List of documents to rerank
220
  model: Jina rerank model name
221
+ top_n: Number of top results to return
222
  base_url: Jina API endpoint
223
  api_key: Jina API key
224
  **kwargs: Additional parameters
 
235
  model=model,
236
  base_url=base_url,
237
  api_key=api_key,
238
+ top_n=top_n,
239
  **kwargs,
240
  )
241
 
 
244
  query: str,
245
  documents: List[Dict[str, Any]],
246
  model: str = "rerank-english-v2.0",
247
+ top_n: Optional[int] = None,
248
  base_url: str = "https://api.cohere.ai/v1/rerank",
249
  api_key: Optional[str] = None,
250
  **kwargs,
 
256
  query: The search query
257
  documents: List of documents to rerank
258
  model: Cohere rerank model name
259
+ top_n: Number of top results to return
260
  base_url: Cohere API endpoint
261
  api_key: Cohere API key
262
  **kwargs: Additional parameters
 
273
  model=model,
274
  base_url=base_url,
275
  api_key=api_key,
276
+ top_n=top_n,
277
  **kwargs,
278
  )
279
 
 
285
  model: str,
286
  base_url: str,
287
  api_key: str,
288
+ top_n: Optional[int] = None,
289
  **kwargs,
290
  ) -> List[Dict[str, Any]]:
291
  """
 
298
  model=model,
299
  base_url=base_url,
300
  api_key=api_key,
301
+ top_n=top_n,
302
  **kwargs,
303
  )
304
 
 
317
  query = "What is the capital of France?"
318
 
319
  result = await jina_rerank(
320
+ query=query, documents=docs, top_n=2, api_key="your-api-key-here"
321
  )
322
  print(result)
323