Commit 0b197bc committed by zrguo
Parent(s): 24a98c3

update chunks truncation method

Files changed (5):
  1. README-zh.md +10 -0
  2. README.md +11 -1
  3. env.example +3 -1
  4. lightrag/base.py +11 -16
  5. lightrag/operate.py +176 -162
README-zh.md CHANGED
@@ -294,6 +294,16 @@ class QueryParam:
     top_k: int = int(os.getenv("TOP_K", "60"))
     """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
 
+    chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5"))
+    """Number of text chunks to retrieve initially from vector search.
+    If None, defaults to top_k value.
+    """
+
+    chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5"))
+    """Number of text chunks to keep after reranking.
+    If None, keeps all chunks returned from initial retrieval.
+    """
+
     max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))
     """Maximum number of tokens allowed for each retrieved text chunk."""
README.md CHANGED
@@ -153,7 +153,7 @@ curl https://raw.githubusercontent.com/gusye1234/nano-graphrag/main/tests/mock_d
     python examples/lightrag_openai_demo.py
     ```
 
-For a streaming response implementation example, please see `examples/lightrag_openai_compatible_demo.py`. Prior to execution, ensure you modify the sample codes LLM and embedding configurations accordingly.
+For a streaming response implementation example, please see `examples/lightrag_openai_compatible_demo.py`. Prior to execution, ensure you modify the sample code's LLM and embedding configurations accordingly.
 
 **Note 1**: When running the demo program, please be aware that different test scripts may use different embedding models. If you switch to a different embedding model, you must clear the data directory (`./dickens`); otherwise, the program may encounter errors. If you wish to retain the LLM cache, you can preserve the `kv_store_llm_response_cache.json` file while clearing the data directory.

@@ -300,6 +300,16 @@ class QueryParam:
     top_k: int = int(os.getenv("TOP_K", "60"))
     """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
 
+    chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5"))
+    """Number of text chunks to retrieve initially from vector search.
+    If None, defaults to top_k value.
+    """
+
+    chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5"))
+    """Number of text chunks to keep after reranking.
+    If None, keeps all chunks returned from initial retrieval.
+    """
+
     max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))
     """Maximum number of tokens allowed for each retrieved text chunk."""
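The new `chunk_top_k` and `chunk_rerank_top_k` fields can also be set per query rather than through environment variables. A minimal usage sketch, assuming a `LightRAG` instance `rag` has already been initialized as in the demo scripts above (the values shown are illustrative, not recommendations):

```python
from lightrag import QueryParam

# Illustrative per-query settings for the new chunk retrieval/rerank knobs.
param = QueryParam(
    mode="mix",            # vector chunks are only pulled in via "mix" mode
    top_k=60,              # entities ('local') / relationships ('global') to retrieve
    chunk_top_k=10,        # chunks fetched from the vector search stage
    chunk_rerank_top_k=5,  # chunks kept after reranking
)

# `rag` is assumed to be an already-initialized LightRAG instance (see the demos).
response = rag.query("What are the top themes in this story?", param=param)
print(response)
```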
env.example CHANGED
@@ -46,7 +46,9 @@ OLLAMA_EMULATING_MODEL_TAG=latest
 # HISTORY_TURNS=3
 # COSINE_THRESHOLD=0.2
 # TOP_K=60
-# MAX_TOKEN_TEXT_CHUNK=4000
+# CHUNK_TOP_K=5
+# CHUNK_RERANK_TOP_K=5
+# MAX_TOKEN_TEXT_CHUNK=6000
 # MAX_TOKEN_RELATION_DESC=4000
 # MAX_TOKEN_ENTITY_DESC=4000
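One caveat worth noting: these variables are read with `os.getenv(...)` inside the `QueryParam` dataclass body, so the defaults are fixed once, when `lightrag` is imported. They therefore need to be in the environment before the import (via the shell, a process manager, or a `.env` loader). A small sketch of that ordering, with illustrative values:

```python
import os

# Set overrides before importing lightrag; the QueryParam dataclass evaluates
# int(os.getenv(...)) once, when the class body is executed at import time.
os.environ.setdefault("CHUNK_TOP_K", "10")
os.environ.setdefault("CHUNK_RERANK_TOP_K", "5")
os.environ.setdefault("MAX_TOKEN_TEXT_CHUNK", "6000")

from lightrag import QueryParam  # imported after the environment is prepared

print(QueryParam().chunk_top_k)  # -> 10, unless CHUNK_TOP_K was already set
```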
lightrag/base.py CHANGED
@@ -60,7 +60,17 @@ class QueryParam:
     top_k: int = int(os.getenv("TOP_K", "60"))
     """Number of top items to retrieve. Represents entities in 'local' mode and relationships in 'global' mode."""
 
-    max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "4000"))
+    chunk_top_k: int = int(os.getenv("CHUNK_TOP_K", "5"))
+    """Number of text chunks to retrieve initially from vector search.
+    If None, defaults to top_k value.
+    """
+
+    chunk_rerank_top_k: int = int(os.getenv("CHUNK_RERANK_TOP_K", "5"))
+    """Number of text chunks to keep after reranking.
+    If None, keeps all chunks returned from initial retrieval.
+    """
+
+    max_token_for_text_unit: int = int(os.getenv("MAX_TOKEN_TEXT_CHUNK", "6000"))
     """Maximum number of tokens allowed for each retrieved text chunk."""
 
     max_token_for_global_context: int = int(

@@ -280,21 +290,6 @@ class BaseKVStorage(StorageNameSpace, ABC):
             False: if the cache drop failed, or the cache mode is not supported
         """
 
-    # async def drop_cache_by_chunk_ids(self, chunk_ids: list[str] | None = None) -> bool:
-    #     """Delete specific cache records from storage by chunk IDs
-
-    #     Importance notes for in-memory storage:
-    #     1. Changes will be persisted to disk during the next index_done_callback
-    #     2. update flags to notify other processes that data persistence is needed
-
-    #     Args:
-    #         chunk_ids (list[str]): List of chunk IDs to be dropped from storage
-
-    #     Returns:
-    #         True: if the cache drop successfully
-    #         False: if the cache drop failed, or the operation is not supported
-    #     """
-
 
 @dataclass
 class BaseGraphStorage(StorageNameSpace, ABC):
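The docstrings above say a `None` value falls back to `top_k`; on the retrieval side (see the `operate.py` hunk below) that fallback is written as `query_param.chunk_top_k or query_param.top_k`, so any falsy value (None or 0) resolves to `top_k`. A quick standalone illustration:

```python
# Standalone illustration of the `chunk_top_k or top_k` fallback used when
# querying the chunks vector store; values are illustrative.
top_k = 60
for chunk_top_k in (5, None, 0):
    search_top_k = chunk_top_k or top_k
    print(f"chunk_top_k={chunk_top_k!r} -> search_top_k={search_top_k}")
# chunk_top_k=5 -> search_top_k=5
# chunk_top_k=None -> search_top_k=60
# chunk_top_k=0 -> search_top_k=60
```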
lightrag/operate.py CHANGED
@@ -1526,6 +1526,7 @@ async def kg_query(
 
     # Build context
     context = await _build_query_context(
+        query,
         ll_keywords_str,
         hl_keywords_str,
         knowledge_graph_inst,

@@ -1744,93 +1745,52 @@ async def _get_vector_context(
     query: str,
     chunks_vdb: BaseVectorStorage,
     query_param: QueryParam,
-    tokenizer: Tokenizer,
-) -> tuple[list, list, list] | None:
+) -> list[dict]:
     """
-    Retrieve vector context from the vector database.
+    Retrieve text chunks from the vector database without reranking or truncation.
 
-    This function performs vector search to find relevant text chunks for a query,
-    formats them with file path and creation time information.
+    This function performs vector search to find relevant text chunks for a query.
+    Reranking and truncation will be handled later in the unified processing.
 
     Args:
         query: The query string to search for
        chunks_vdb: Vector database containing document chunks
-        query_param: Query parameters including top_k and ids
-        tokenizer: Tokenizer for counting tokens
+        query_param: Query parameters including chunk_top_k and ids
 
     Returns:
-        Tuple (empty_entities, empty_relations, text_units) for combine_contexts,
-        compatible with _get_edge_data and _get_node_data format
+        List of text chunks with metadata
     """
     try:
-        results = await chunks_vdb.query(
-            query, top_k=query_param.top_k, ids=query_param.ids
-        )
+        # Use chunk_top_k if specified, otherwise fall back to top_k
+        search_top_k = query_param.chunk_top_k or query_param.top_k
+
+        results = await chunks_vdb.query(query, top_k=search_top_k, ids=query_param.ids)
         if not results:
-            return [], [], []
+            return []
 
         valid_chunks = []
         for result in results:
             if "content" in result:
-                # Directly use content from chunks_vdb.query result
-                chunk_with_time = {
+                chunk_with_metadata = {
                     "content": result["content"],
                     "created_at": result.get("created_at", None),
                     "file_path": result.get("file_path", "unknown_source"),
+                    "source_type": "vector",  # Mark the source type
                 }
-                valid_chunks.append(chunk_with_time)
-
-        if not valid_chunks:
-            return [], [], []
-
-        # Apply reranking if enabled
-        global_config = chunks_vdb.global_config
-        valid_chunks = await apply_rerank_if_enabled(
-            query=query,
-            retrieved_docs=valid_chunks,
-            global_config=global_config,
-            top_k=query_param.top_k,
-        )
-
-        maybe_trun_chunks = truncate_list_by_token_size(
-            valid_chunks,
-            key=lambda x: x["content"],
-            max_token_size=query_param.max_token_for_text_unit,
-            tokenizer=tokenizer,
-        )
+                valid_chunks.append(chunk_with_metadata)
 
         logger.debug(
-            f"Truncate chunks from {len(valid_chunks)} to {len(maybe_trun_chunks)} (max tokens:{query_param.max_token_for_text_unit})"
-        )
-        logger.info(
-            f"Query chunks: {len(maybe_trun_chunks)} chunks, top_k: {query_param.top_k}"
+            f"Vector search retrieved {len(valid_chunks)} chunks (top_k: {search_top_k})"
         )
+        return valid_chunks
 
-        if not maybe_trun_chunks:
-            return [], [], []
-
-        # Create empty entities and relations contexts
-        entities_context = []
-        relations_context = []
-
-        # Create text_units_context directly as a list of dictionaries
-        text_units_context = []
-        for i, chunk in enumerate(maybe_trun_chunks):
-            text_units_context.append(
-                {
-                    "id": i + 1,
-                    "content": chunk["content"],
-                    "file_path": chunk["file_path"],
-                }
-            )
-
-        return entities_context, relations_context, text_units_context
     except Exception as e:
         logger.error(f"Error in _get_vector_context: {e}")
-        return [], [], []
+        return []
 
 
 async def _build_query_context(
+    query: str,
     ll_keywords: str,
     hl_keywords: str,
     knowledge_graph_inst: BaseGraphStorage,

@@ -1838,27 +1798,36 @@ async def _build_query_context(
     relationships_vdb: BaseVectorStorage,
     text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
-    chunks_vdb: BaseVectorStorage = None,  # Add chunks_vdb parameter for mix mode
+    chunks_vdb: BaseVectorStorage = None,
 ):
     logger.info(f"Process {os.getpid()} building query context...")
 
-    # Handle local and global modes as before
+    # Collect all chunks from different sources
+    all_chunks = []
+    entities_context = []
+    relations_context = []
+
+    # Handle local and global modes
     if query_param.mode == "local":
-        entities_context, relations_context, text_units_context = await _get_node_data(
+        entities_context, relations_context, entity_chunks = await _get_node_data(
            ll_keywords,
            knowledge_graph_inst,
            entities_vdb,
            text_chunks_db,
            query_param,
        )
+        all_chunks.extend(entity_chunks)
+
     elif query_param.mode == "global":
-        entities_context, relations_context, text_units_context = await _get_edge_data(
+        entities_context, relations_context, relationship_chunks = await _get_edge_data(
            hl_keywords,
            knowledge_graph_inst,
            relationships_vdb,
            text_chunks_db,
            query_param,
        )
+        all_chunks.extend(relationship_chunks)
+
     else:  # hybrid or mix mode
         ll_data = await _get_node_data(
             ll_keywords,

@@ -1875,61 +1844,58 @@ async def _build_query_context(
             query_param,
         )
 
-        (
-            ll_entities_context,
-            ll_relations_context,
-            ll_text_units_context,
-        ) = ll_data
-
-        (
-            hl_entities_context,
-            hl_relations_context,
-            hl_text_units_context,
-        ) = hl_data
-
-        # Initialize vector data with empty lists
-        vector_entities_context, vector_relations_context, vector_text_units_context = (
-            [],
-            [],
-            [],
-        )
-
-        # Only get vector data if in mix mode
-        if query_param.mode == "mix" and hasattr(query_param, "original_query"):
-            # Get tokenizer from text_chunks_db
-            tokenizer = text_chunks_db.global_config.get("tokenizer")
-
-            # Get vector context in triple format
-            vector_data = await _get_vector_context(
-                query_param.original_query,  # We need to pass the original query
+        (ll_entities_context, ll_relations_context, ll_chunks) = ll_data
+        (hl_entities_context, hl_relations_context, hl_chunks) = hl_data
+
+        # Collect chunks from entity and relationship sources
+        all_chunks.extend(ll_chunks)
+        all_chunks.extend(hl_chunks)
+
+        # Get vector chunks if in mix mode
+        if query_param.mode == "mix" and chunks_vdb:
+            vector_chunks = await _get_vector_context(
+                query,
                 chunks_vdb,
                 query_param,
-                tokenizer,
             )
+            all_chunks.extend(vector_chunks)
 
-            # If vector_data is not None, unpack it
-            if vector_data is not None:
-                (
-                    vector_entities_context,
-                    vector_relations_context,
-                    vector_text_units_context,
-                ) = vector_data
-
-        # Combine and deduplicate the entities, relationships, and sources
+        # Combine entities and relations contexts
         entities_context = process_combine_contexts(
-            hl_entities_context, ll_entities_context, vector_entities_context
+            hl_entities_context, ll_entities_context
        )
        relations_context = process_combine_contexts(
-            hl_relations_context, ll_relations_context, vector_relations_context
+            hl_relations_context, ll_relations_context
        )
-        text_units_context = process_combine_contexts(
-            hl_text_units_context, ll_text_units_context, vector_text_units_context
-        )
+
+    # Process all chunks uniformly: deduplication, reranking, and token truncation
+    processed_chunks = await process_chunks_unified(
+        query=query,
+        chunks=all_chunks,
+        query_param=query_param,
+        global_config=text_chunks_db.global_config,
+        source_type="mixed",
+    )
+
+    # Build final text_units_context from processed chunks
+    text_units_context = []
+    for i, chunk in enumerate(processed_chunks):
+        text_units_context.append(
+            {
+                "id": i + 1,
+                "content": chunk["content"],
+                "file_path": chunk.get("file_path", "unknown_source"),
+            }
+        )
+
+    logger.info(
+        f"Final context: {len(entities_context)} entities, {len(relations_context)} relations, {len(text_units_context)} chunks"
+    )
+
     # not necessary to use LLM to generate a response
     if not entities_context and not relations_context:
         return None
 
-    # Convert to JSON strings
     entities_str = json.dumps(entities_context, ensure_ascii=False)
     relations_str = json.dumps(relations_context, ensure_ascii=False)
     text_units_str = json.dumps(text_units_context, ensure_ascii=False)

@@ -1975,15 +1941,6 @@ async def _get_node_data(
     if not len(results):
         return "", "", ""
 
-    # Apply reranking if enabled for entity results
-    global_config = entities_vdb.global_config
-    results = await apply_rerank_if_enabled(
-        query=query,
-        retrieved_docs=results,
-        global_config=global_config,
-        top_k=query_param.top_k,
-    )
-
     # Extract all entity IDs from your results list
     node_ids = [r["entity_name"] for r in results]
 

@@ -2085,16 +2042,7 @@ async def _get_node_data(
             }
         )
 
-    text_units_context = []
-    for i, t in enumerate(use_text_units):
-        text_units_context.append(
-            {
-                "id": i + 1,
-                "content": t["content"],
-                "file_path": t.get("file_path", "unknown_source"),
-            }
-        )
-    return entities_context, relations_context, text_units_context
+    return entities_context, relations_context, use_text_units
 
 
 async def _find_most_related_text_unit_from_entities(

@@ -2183,23 +2131,21 @@ async def _find_most_related_text_unit_from_entities(
         logger.warning("No valid text units found")
         return []
 
-    tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
+    # Sort by relation counts and order, but don't truncate
     all_text_units = sorted(
         all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
     )
-    all_text_units = truncate_list_by_token_size(
-        all_text_units,
-        key=lambda x: x["data"]["content"],
-        max_token_size=query_param.max_token_for_text_unit,
-        tokenizer=tokenizer,
-    )
 
-    logger.debug(
-        f"Truncate chunks from {len(all_text_units_lookup)} to {len(all_text_units)} (max tokens:{query_param.max_token_for_text_unit})"
-    )
+    logger.debug(f"Found {len(all_text_units)} entity-related chunks")
+
+    # Add source type marking and return chunk data
+    result_chunks = []
+    for t in all_text_units:
+        chunk_data = t["data"].copy()
+        chunk_data["source_type"] = "entity"
+        result_chunks.append(chunk_data)
 
-    all_text_units = [t["data"] for t in all_text_units]
-    return all_text_units
+    return result_chunks
 
 
 async def _find_most_related_edges_from_entities(

@@ -2287,15 +2233,6 @@ async def _get_edge_data(
     if not len(results):
         return "", "", ""
 
-    # Apply reranking if enabled for relationship results
-    global_config = relationships_vdb.global_config
-    results = await apply_rerank_if_enabled(
-        query=keywords,
-        retrieved_docs=results,
-        global_config=global_config,
-        top_k=query_param.top_k,
-    )
-
     # Prepare edge pairs in two forms:
     # For the batch edge properties function, use dicts.
     edge_pairs_dicts = [{"src": r["src_id"], "tgt": r["tgt_id"]} for r in results]

@@ -2510,21 +2447,16 @@ async def _find_related_text_unit_from_relationships(
         logger.warning("No valid text chunks after filtering")
         return []
 
-    tokenizer: Tokenizer = text_chunks_db.global_config.get("tokenizer")
-    truncated_text_units = truncate_list_by_token_size(
-        valid_text_units,
-        key=lambda x: x["data"]["content"],
-        max_token_size=query_param.max_token_for_text_unit,
-        tokenizer=tokenizer,
-    )
-
-    logger.debug(
-        f"Truncate chunks from {len(valid_text_units)} to {len(truncated_text_units)} (max tokens:{query_param.max_token_for_text_unit})"
-    )
+    logger.debug(f"Found {len(valid_text_units)} relationship-related chunks")
 
-    all_text_units: list[TextChunkSchema] = [t["data"] for t in truncated_text_units]
+    # Add source type marking and return chunk data
+    result_chunks = []
+    for t in valid_text_units:
+        chunk_data = t["data"].copy()
+        chunk_data["source_type"] = "relationship"
+        result_chunks.append(chunk_data)
 
-    return all_text_units
+    return result_chunks
 
 
 async def naive_query(

@@ -2552,13 +2484,31 @@ async def naive_query(
 
     tokenizer: Tokenizer = global_config["tokenizer"]
 
-    _, _, text_units_context = await _get_vector_context(
-        query, chunks_vdb, query_param, tokenizer
-    )
+    chunks = await _get_vector_context(query, chunks_vdb, query_param)
 
-    if text_units_context is None or len(text_units_context) == 0:
+    if chunks is None or len(chunks) == 0:
         return PROMPTS["fail_response"]
 
+    # Process chunks using unified processing
+    processed_chunks = await process_chunks_unified(
+        query=query,
+        chunks=chunks,
+        query_param=query_param,
+        global_config=global_config,
+        source_type="vector",
+    )
+
+    # Build text_units_context from processed chunks
+    text_units_context = []
+    for i, chunk in enumerate(processed_chunks):
+        text_units_context.append(
+            {
+                "id": i + 1,
+                "content": chunk["content"],
+                "file_path": chunk.get("file_path", "unknown_source"),
+            }
+        )
+
     text_units_str = json.dumps(text_units_context, ensure_ascii=False)
     if query_param.only_need_context:
         return f"""

@@ -2683,6 +2633,7 @@ async def kg_query_with_keywords(
     hl_keywords_str = ", ".join(hl_keywords) if hl_keywords else ""
 
     context = await _build_query_context(
+        query,
         ll_keywords_str,
         hl_keywords_str,
         knowledge_graph_inst,

@@ -2805,8 +2756,6 @@ async def query_with_keywords(
         f"{prompt}\n\n### Keywords\n\n{keywords_str}\n\n### Query\n\n{query}"
     )
 
-    param.original_query = query
-
     # Use appropriate query method based on mode
     if param.mode in ["local", "global", "hybrid", "mix"]:
         return await kg_query_with_keywords(

@@ -2887,3 +2836,68 @@ async def apply_rerank_if_enabled(
     except Exception as e:
         logger.error(f"Error during reranking: {e}, using original documents")
         return retrieved_docs
+
+
+async def process_chunks_unified(
+    query: str,
+    chunks: list[dict],
+    query_param: QueryParam,
+    global_config: dict,
+    source_type: str = "mixed",
+) -> list[dict]:
+    """
+    Unified processing for text chunks: deduplication, reranking, and token truncation.
+
+    Args:
+        query: Search query for reranking
+        chunks: List of text chunks to process
+        query_param: Query parameters containing configuration
+        global_config: Global configuration dictionary
+        source_type: Source type for logging ("vector", "entity", "relationship", "mixed")
+
+    Returns:
+        Processed and filtered list of text chunks
+    """
+    if not chunks:
+        return []
+
+    # 1. Deduplication based on content
+    seen_content = set()
+    unique_chunks = []
+    for chunk in chunks:
+        content = chunk.get("content", "")
+        if content and content not in seen_content:
+            seen_content.add(content)
+            unique_chunks.append(chunk)
+
+    logger.debug(
+        f"Deduplication: {len(unique_chunks)} chunks (original: {len(chunks)})"
+    )
+
+    # 2. Apply reranking if enabled and query is provided
+    if global_config.get("enable_rerank", False) and query and unique_chunks:
+        rerank_top_k = query_param.chunk_rerank_top_k or len(unique_chunks)
+        unique_chunks = await apply_rerank_if_enabled(
+            query=query,
+            retrieved_docs=unique_chunks,
+            global_config=global_config,
+            top_k=rerank_top_k,
+        )
+        logger.debug(f"Rerank: {len(unique_chunks)} chunks (source: {source_type})")
+
+    # 3. Token-based final truncation
+    tokenizer = global_config.get("tokenizer")
+    if tokenizer and unique_chunks:
+        original_count = len(unique_chunks)
+        unique_chunks = truncate_list_by_token_size(
+            unique_chunks,
+            key=lambda x: x.get("content", ""),
+            max_token_size=query_param.max_token_for_text_unit,
+            tokenizer=tokenizer,
+        )
+        logger.debug(
+            f"Token truncation: {len(unique_chunks)} chunks from {original_count} "
+            f"(max tokens: {query_param.max_token_for_text_unit}, source: {source_type})"
+        )
+
+    return unique_chunks
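For orientation, the order that `process_chunks_unified` enforces is: deduplicate by content first, then rerank (keeping `chunk_rerank_top_k` chunks), then truncate to the `max_token_for_text_unit` budget. Below is a simplified, self-contained sketch of that ordering with toy data and a whitespace "tokenizer"; none of the names in it are LightRAG APIs.

```python
# Toy chunks; "score" stands in for whatever the configured rerank model returns.
chunks = [
    {"content": "alpha beta gamma", "score": 0.2},
    {"content": "alpha beta gamma", "score": 0.9},   # duplicate content, dropped in step 1
    {"content": "delta epsilon", "score": 0.8},
    {"content": "zeta eta theta iota", "score": 0.5},
]

# 1. Deduplicate by content (first occurrence wins).
seen, unique = set(), []
for c in chunks:
    if c["content"] not in seen:
        seen.add(c["content"])
        unique.append(c)

# 2. "Rerank": sort by the toy score and keep chunk_rerank_top_k chunks.
chunk_rerank_top_k = 2
unique = sorted(unique, key=lambda c: c["score"], reverse=True)[:chunk_rerank_top_k]

# 3. Token truncation against max_token_for_text_unit, counting whitespace tokens.
max_token_for_text_unit = 5
kept, used = [], 0
for c in unique:
    n = len(c["content"].split())
    if used + n > max_token_for_text_unit:
        break
    kept.append(c)
    used += n

print([c["content"] for c in kept])  # ['delta epsilon']
```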