zrguo
committed on
Commit
·
b4d24bd
1
Parent(s):
0b197bc
fix chunk_top_k limiting
Browse files- examples/rerank_example.py +7 -0
- lightrag/operate.py +13 -4
examples/rerank_example.py
CHANGED
@@ -20,6 +20,7 @@ from lightrag import LightRAG, QueryParam
|
|
20 |
from lightrag.rerank import custom_rerank, RerankModel
|
21 |
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
|
22 |
from lightrag.utils import EmbeddingFunc, setup_logger
|
|
|
23 |
|
24 |
# Set up your working directory
|
25 |
WORKING_DIR = "./test_rerank"
|
@@ -87,6 +88,9 @@ async def create_rag_with_rerank():
|
|
87 |
rerank_model_func=my_rerank_func,
|
88 |
)
|
89 |
|
|
|
|
|
|
|
90 |
return rag
|
91 |
|
92 |
|
@@ -120,6 +124,9 @@ async def create_rag_with_rerank_model():
|
|
120 |
rerank_model_func=rerank_model.rerank,
|
121 |
)
|
122 |
|
|
|
|
|
|
|
123 |
return rag
|
124 |
|
125 |
|
|
|
20 |
from lightrag.rerank import custom_rerank, RerankModel
|
21 |
from lightrag.llm.openai import openai_complete_if_cache, openai_embed
|
22 |
from lightrag.utils import EmbeddingFunc, setup_logger
|
23 |
+
from lightrag.kg.shared_storage import initialize_pipeline_status
|
24 |
|
25 |
# Set up your working directory
|
26 |
WORKING_DIR = "./test_rerank"
|
|
|
88 |
rerank_model_func=my_rerank_func,
|
89 |
)
|
90 |
|
91 |
+
await rag.initialize_storages()
|
92 |
+
await initialize_pipeline_status()
|
93 |
+
|
94 |
return rag
|
95 |
|
96 |
|
|
|
124 |
rerank_model_func=rerank_model.rerank,
|
125 |
)
|
126 |
|
127 |
+
await rag.initialize_storages()
|
128 |
+
await initialize_pipeline_status()
|
129 |
+
|
130 |
return rag
|
131 |
|
132 |
|
lightrag/operate.py
CHANGED
@@ -2823,8 +2823,9 @@ async def apply_rerank_if_enabled(
|
|
2823 |
documents=retrieved_docs,
|
2824 |
top_k=top_k,
|
2825 |
)
|
2826 |
-
|
2827 |
if reranked_docs and len(reranked_docs) > 0:
|
|
|
|
|
2828 |
logger.info(
|
2829 |
f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}"
|
2830 |
)
|
@@ -2846,7 +2847,7 @@ async def process_chunks_unified(
|
|
2846 |
source_type: str = "mixed",
|
2847 |
) -> list[dict]:
|
2848 |
"""
|
2849 |
-
Unified processing for text chunks: deduplication, reranking, and token truncation.
|
2850 |
|
2851 |
Args:
|
2852 |
query: Search query for reranking
|
@@ -2874,7 +2875,15 @@ async def process_chunks_unified(
|
|
2874 |
f"Deduplication: {len(unique_chunks)} chunks (original: {len(chunks)})"
|
2875 |
)
|
2876 |
|
2877 |
-
# 2. Apply
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2878 |
if global_config.get("enable_rerank", False) and query and unique_chunks:
|
2879 |
rerank_top_k = query_param.chunk_rerank_top_k or len(unique_chunks)
|
2880 |
unique_chunks = await apply_rerank_if_enabled(
|
@@ -2885,7 +2894,7 @@ async def process_chunks_unified(
|
|
2885 |
)
|
2886 |
logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})")
|
2887 |
|
2888 |
-
#
|
2889 |
tokenizer = global_config.get("tokenizer")
|
2890 |
if tokenizer and unique_chunks:
|
2891 |
original_count = len(unique_chunks)
|
|
|
2823 |
documents=retrieved_docs,
|
2824 |
top_k=top_k,
|
2825 |
)
|
|
|
2826 |
if reranked_docs and len(reranked_docs) > 0:
|
2827 |
+
if len(reranked_docs) > top_k:
|
2828 |
+
reranked_docs = reranked_docs[:top_k]
|
2829 |
logger.info(
|
2830 |
f"Successfully reranked {len(retrieved_docs)} documents to {len(reranked_docs)}"
|
2831 |
)
|
|
|
2847 |
source_type: str = "mixed",
|
2848 |
) -> list[dict]:
|
2849 |
"""
|
2850 |
+
Unified processing for text chunks: deduplication, chunk_top_k limiting, reranking, and token truncation.
|
2851 |
|
2852 |
Args:
|
2853 |
query: Search query for reranking
|
|
|
2875 |
f"Deduplication: {len(unique_chunks)} chunks (original: {len(chunks)})"
|
2876 |
)
|
2877 |
|
2878 |
+
# 2. Apply chunk_top_k limiting if specified
|
2879 |
+
if query_param.chunk_top_k is not None and query_param.chunk_top_k > 0:
|
2880 |
+
if len(unique_chunks) > query_param.chunk_top_k:
|
2881 |
+
unique_chunks = unique_chunks[: query_param.chunk_top_k]
|
2882 |
+
logger.debug(
|
2883 |
+
f"Chunk top-k limiting: kept {len(unique_chunks)} chunks (chunk_top_k={query_param.chunk_top_k})"
|
2884 |
+
)
|
2885 |
+
|
2886 |
+
# 3. Apply reranking if enabled and query is provided
|
2887 |
if global_config.get("enable_rerank", False) and query and unique_chunks:
|
2888 |
rerank_top_k = query_param.chunk_rerank_top_k or len(unique_chunks)
|
2889 |
unique_chunks = await apply_rerank_if_enabled(
|
|
|
2894 |
)
|
2895 |
logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})")
|
2896 |
|
2897 |
+
# 4. Token-based final truncation
|
2898 |
tokenizer = global_config.get("tokenizer")
|
2899 |
if tokenizer and unique_chunks:
|
2900 |
original_count = len(unique_chunks)
|