zrguo commited on
Commit
f951108
·
1 Parent(s): 4225cbd

Add rerank to server

Browse files
env.example CHANGED
@@ -46,8 +46,19 @@ OLLAMA_EMULATING_MODEL_TAG=latest
46
  # HISTORY_TURNS=3
47
  # COSINE_THRESHOLD=0.2
48
  # TOP_K=60
 
49
  # CHUNK_TOP_K=5
 
 
 
 
 
50
  # CHUNK_RERANK_TOP_K=5
 
 
 
 
 
51
  # MAX_TOKEN_TEXT_CHUNK=6000
52
  # MAX_TOKEN_RELATION_DESC=4000
53
  # MAX_TOKEN_ENTITY_DESC=4000
@@ -181,6 +192,3 @@ QDRANT_URL=http://localhost:6333
181
  ### Redis
182
  REDIS_URI=redis://localhost:6379
183
  # REDIS_WORKSPACE=forced_workspace_name
184
-
185
- # Rerank Configuration
186
- ENABLE_RERANK=False
 
46
  # HISTORY_TURNS=3
47
  # COSINE_THRESHOLD=0.2
48
  # TOP_K=60
49
+ ### Number of text chunks to retrieve initially from vector search
50
  # CHUNK_TOP_K=5
51
+
52
+ ### Rerank Configuration
53
+ ### Enable rerank functionality to improve retrieval quality
54
+ # ENABLE_RERANK=False
55
+ ### Number of text chunks to keep after reranking (should be <= CHUNK_TOP_K)
56
  # CHUNK_RERANK_TOP_K=5
57
+ ### Rerank model configuration (RERANK_BINDING_HOST and RERANK_BINDING_API_KEY are required when ENABLE_RERANK=True; RERANK_MODEL defaults to BAAI/bge-reranker-v2-m3)
58
+ # RERANK_MODEL=BAAI/bge-reranker-v2-m3
59
+ # RERANK_BINDING_HOST=https://api.your-rerank-provider.com/v1/rerank
60
+ # RERANK_BINDING_API_KEY=your_rerank_api_key_here
61
+
62
  # MAX_TOKEN_TEXT_CHUNK=6000
63
  # MAX_TOKEN_RELATION_DESC=4000
64
  # MAX_TOKEN_ENTITY_DESC=4000
 
192
  ### Redis
193
  REDIS_URI=redis://localhost:6379
194
  # REDIS_WORKSPACE=forced_workspace_name
 
 
 
lightrag/api/config.py CHANGED
@@ -165,6 +165,24 @@ def parse_args() -> argparse.Namespace:
165
  default=get_env_value("TOP_K", 60, int),
166
  help="Number of most similar results to return (default: from env or 60)",
167
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  parser.add_argument(
169
  "--cosine-threshold",
170
  type=float,
@@ -295,6 +313,11 @@ def parse_args() -> argparse.Namespace:
295
  args.guest_token_expire_hours = get_env_value("GUEST_TOKEN_EXPIRE_HOURS", 24, int)
296
  args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256")
297
 
 
 
 
 
 
298
  ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
299
 
300
  return args
 
165
  default=get_env_value("TOP_K", 60, int),
166
  help="Number of most similar results to return (default: from env or 60)",
167
  )
168
+ parser.add_argument(
169
+ "--chunk-top-k",
170
+ type=int,
171
+ default=get_env_value("CHUNK_TOP_K", 5, int),
172
+ help="Number of text chunks to retrieve initially from vector search (default: from env or 5)",
173
+ )
174
+ parser.add_argument(
175
+ "--chunk-rerank-top-k",
176
+ type=int,
177
+ default=get_env_value("CHUNK_RERANK_TOP_K", 5, int),
178
+ help="Number of text chunks to keep after reranking (default: from env or 5)",
179
+ )
180
+ parser.add_argument(
181
+ "--enable-rerank",
182
+ action="store_true",
183
+ default=get_env_value("ENABLE_RERANK", False, bool),
184
+ help="Enable rerank functionality (default: from env or False)",
185
+ )
186
  parser.add_argument(
187
  "--cosine-threshold",
188
  type=float,
 
313
  args.guest_token_expire_hours = get_env_value("GUEST_TOKEN_EXPIRE_HOURS", 24, int)
314
  args.jwt_algorithm = get_env_value("JWT_ALGORITHM", "HS256")
315
 
316
+ # Rerank model configuration
317
+ args.rerank_model = get_env_value("RERANK_MODEL", "BAAI/bge-reranker-v2-m3")
318
+ args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None)
319
+ args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None)
320
+
321
  ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
322
 
323
  return args
lightrag/api/lightrag_server.py CHANGED
@@ -291,6 +291,32 @@ def create_app(args):
291
  ),
292
  )
293
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  # Initialize RAG
295
  if args.llm_binding in ["lollms", "ollama", "openai"]:
296
  rag = LightRAG(
@@ -324,6 +350,8 @@ def create_app(args):
324
  },
325
  enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
326
  enable_llm_cache=args.enable_llm_cache,
 
 
327
  auto_manage_storages_states=False,
328
  max_parallel_insert=args.max_parallel_insert,
329
  max_graph_nodes=args.max_graph_nodes,
@@ -352,6 +380,8 @@ def create_app(args):
352
  },
353
  enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
354
  enable_llm_cache=args.enable_llm_cache,
 
 
355
  auto_manage_storages_states=False,
356
  max_parallel_insert=args.max_parallel_insert,
357
  max_graph_nodes=args.max_graph_nodes,
@@ -478,6 +508,12 @@ def create_app(args):
478
  "enable_llm_cache": args.enable_llm_cache,
479
  "workspace": args.workspace,
480
  "max_graph_nodes": args.max_graph_nodes,
 
 
 
 
 
 
481
  },
482
  "auth_mode": auth_mode,
483
  "pipeline_busy": pipeline_status.get("busy", False),
 
291
  ),
292
  )
293
 
294
+ # Configure rerank function if enabled
295
+ rerank_model_func = None
296
+ if args.enable_rerank and args.rerank_binding_api_key and args.rerank_binding_host:
297
+ from lightrag.rerank import custom_rerank
298
+
299
+ async def server_rerank_func(
300
+ query: str, documents: list, top_k: int = None, **kwargs
301
+ ):
302
+ """Server rerank function with configuration from environment variables"""
303
+ return await custom_rerank(
304
+ query=query,
305
+ documents=documents,
306
+ model=args.rerank_model,
307
+ base_url=args.rerank_binding_host,
308
+ api_key=args.rerank_binding_api_key,
309
+ top_k=top_k,
310
+ **kwargs,
311
+ )
312
+
313
+ rerank_model_func = server_rerank_func
314
+ logger.info(f"Rerank enabled with model: {args.rerank_model}")
315
+ elif args.enable_rerank:
316
+ logger.warning(
317
+ "Rerank enabled but RERANK_BINDING_API_KEY or RERANK_BINDING_HOST not configured. Rerank will be disabled."
318
+ )
319
+
320
  # Initialize RAG
321
  if args.llm_binding in ["lollms", "ollama", "openai"]:
322
  rag = LightRAG(
 
350
  },
351
  enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
352
  enable_llm_cache=args.enable_llm_cache,
353
+ enable_rerank=args.enable_rerank,
354
+ rerank_model_func=rerank_model_func,
355
  auto_manage_storages_states=False,
356
  max_parallel_insert=args.max_parallel_insert,
357
  max_graph_nodes=args.max_graph_nodes,
 
380
  },
381
  enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
382
  enable_llm_cache=args.enable_llm_cache,
383
+ enable_rerank=args.enable_rerank,
384
+ rerank_model_func=rerank_model_func,
385
  auto_manage_storages_states=False,
386
  max_parallel_insert=args.max_parallel_insert,
387
  max_graph_nodes=args.max_graph_nodes,
 
508
  "enable_llm_cache": args.enable_llm_cache,
509
  "workspace": args.workspace,
510
  "max_graph_nodes": args.max_graph_nodes,
511
+ # Rerank configuration
512
+ "enable_rerank": args.enable_rerank,
513
+ "rerank_model": args.rerank_model if args.enable_rerank else None,
514
+ "rerank_binding_host": args.rerank_binding_host
515
+ if args.enable_rerank
516
+ else None,
517
  },
518
  "auth_mode": auth_mode,
519
  "pipeline_busy": pipeline_status.get("busy", False),
lightrag/api/routers/query_routes.py CHANGED
@@ -49,6 +49,18 @@ class QueryRequest(BaseModel):
49
  description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.",
50
  )
51
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  max_token_for_text_unit: Optional[int] = Field(
53
  gt=1,
54
  default=None,
 
49
  description="Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode.",
50
  )
51
 
52
+ chunk_top_k: Optional[int] = Field(
53
+ ge=1,
54
+ default=None,
55
+ description="Number of text chunks to retrieve initially from vector search.",
56
+ )
57
+
58
+ chunk_rerank_top_k: Optional[int] = Field(
59
+ ge=1,
60
+ default=None,
61
+ description="Number of text chunks to keep after reranking.",
62
+ )
63
+
64
  max_token_for_text_unit: Optional[int] = Field(
65
  gt=1,
66
  default=None,