zrguo commited on
Commit
b4d24bd
·
1 Parent(s): 0b197bc

fix chunk_top_k limiting

Browse files
Files changed (2) hide show
  1. examples/rerank_example.py +7 -0
  2. lightrag/operate.py +13 -4
examples/rerank_example.py CHANGED
@@ -20,6 +20,7 @@ from lightrag import LightRAG, QueryParam
20
  from lightrag.rerank import custom_rerank, RerankModel
21
  from lightrag.llm.openai import openai_complete_if_cache, openai_embed
22
  from lightrag.utils import EmbeddingFunc, setup_logger
 
23
 
24
  # Set up your working directory
25
  WORKING_DIR = "./test_rerank"
@@ -87,6 +88,9 @@ async def create_rag_with_rerank():
87
  rerank_model_func=my_rerank_func,
88
  )
89
 
 
 
 
90
  return rag
91
 
92
 
@@ -120,6 +124,9 @@ async def create_rag_with_rerank_model():
120
  rerank_model_func=rerank_model.rerank,
121
  )
122
 
 
 
 
123
  return rag
124
 
125
 
 
20
  from lightrag.rerank import custom_rerank, RerankModel
21
  from lightrag.llm.openai import openai_complete_if_cache, openai_embed
22
  from lightrag.utils import EmbeddingFunc, setup_logger
23
+ from lightrag.kg.shared_storage import initialize_pipeline_status
24
 
25
  # Set up your working directory
26
  WORKING_DIR = "./test_rerank"
 
88
  rerank_model_func=my_rerank_func,
89
  )
90
 
91
+ await rag.initialize_storages()
92
+ await initialize_pipeline_status()
93
+
94
  return rag
95
 
96
 
 
124
  rerank_model_func=rerank_model.rerank,
125
  )
126
 
127
+ await rag.initialize_storages()
128
+ await initialize_pipeline_status()
129
+
130
  return rag
131
 
132
 
lightrag/operate.py CHANGED
@@ -2823,8 +2823,9 @@ async def apply_rerank_if_enabled(
2823
  documents=retrieved_docs,
2824
  top_k=top_k,
2825
  )
2826
-
2827
  if reranked_docs and len(reranked_docs) > 0:
 
 
2828
  logger.info(
2829
  f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}"
2830
  )
@@ -2846,7 +2847,7 @@ async def process_chunks_unified(
2846
  source_type: str = "mixed",
2847
  ) -> list[dict]:
2848
  """
2849
- Unified processing for text chunks: deduplication, reranking, and token truncation.
2850
 
2851
  Args:
2852
  query: Search query for reranking
@@ -2874,7 +2875,15 @@ async def process_chunks_unified(
2874
  f"Deduplication: {len(unique_chunks)} chunks (original: {len(chunks)})"
2875
  )
2876
 
2877
- # 2. Apply reranking if enabled and query is provided
 
 
 
 
 
 
 
 
2878
  if global_config.get("enable_rerank", False) and query and unique_chunks:
2879
  rerank_top_k = query_param.chunk_rerank_top_k or len(unique_chunks)
2880
  unique_chunks = await apply_rerank_if_enabled(
@@ -2885,7 +2894,7 @@ async def process_chunks_unified(
2885
  )
2886
  logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})")
2887
 
2888
- # 3. Token-based final truncation
2889
  tokenizer = global_config.get("tokenizer")
2890
  if tokenizer and unique_chunks:
2891
  original_count = len(unique_chunks)
 
2823
  documents=retrieved_docs,
2824
  top_k=top_k,
2825
  )
 
2826
  if reranked_docs and len(reranked_docs) > 0:
2827
+ if len(reranked_docs) > top_k:
2828
+ reranked_docs = reranked_docs[:top_k]
2829
  logger.info(
2830
  f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}"
2831
  )
 
2847
  source_type: str = "mixed",
2848
  ) -> list[dict]:
2849
  """
2850
+ Unified processing for text chunks: deduplication, chunk_top_k limiting, reranking, and token truncation.
2851
 
2852
  Args:
2853
  query: Search query for reranking
 
2875
  f"Deduplication: {len(unique_chunks)} chunks (original: {len(chunks)})"
2876
  )
2877
 
2878
+ # 2. Apply chunk_top_k limiting if specified
2879
+ if query_param.chunk_top_k is not None and query_param.chunk_top_k > 0:
2880
+ if len(unique_chunks) > query_param.chunk_top_k:
2881
+ unique_chunks = unique_chunks[: query_param.chunk_top_k]
2882
+ logger.debug(
2883
+ f"Chunk top-k limiting: kept {len(unique_chunks)} chunks (chunk_top_k={query_param.chunk_top_k})"
2884
+ )
2885
+
2886
+ # 3. Apply reranking if enabled and query is provided
2887
  if global_config.get("enable_rerank", False) and query and unique_chunks:
2888
  rerank_top_k = query_param.chunk_rerank_top_k or len(unique_chunks)
2889
  unique_chunks = await apply_rerank_if_enabled(
 
2894
  )
2895
  logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})")
2896
 
2897
+ # 4. Token-based final truncation
2898
  tokenizer = global_config.get("tokenizer")
2899
  if tokenizer and unique_chunks:
2900
  original_count = len(unique_chunks)