gzdaniel committed
Commit 436d92f · 1 Parent(s): c3b7212

Centralize query parameters into LightRAG class


This commit refactors query parameter management by consolidating settings such as `top_k`, the token limits, and the similarity threshold into fields on the `LightRAG` class, so that every query path sources these parameters from a single location instead of scattered CLI flags and uppercase config keys.
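After the refactor, these settings are ordinary dataclass fields on `LightRAG`, each defaulting to its environment variable and then to a package constant. A minimal sketch of the resulting construction, assuming the usual required arguments (working directory, LLM and embedding functions) are supplied elsewhere; the values shown are simply the documented defaults:

```python
# Sketch only: the query-parameter fields come from this commit; everything
# else about constructing LightRAG is elided or assumed.
from lightrag import LightRAG

rag = LightRAG(
    working_dir="./rag_storage",  # hypothetical path
    top_k=60,                     # entities/relations retrieved per query
    chunk_top_k=5,                # text chunks kept in context
    max_total_tokens=32000,       # overall context budget sent to the LLM
    cosine_threshold=0.2,         # similarity floor for vector retrieval
    related_chunk_number=5,       # chunks pulled per entity/relation
)
# Fields left unset fall back to env vars (TOP_K, CHUNK_TOP_K, MAX_TOTAL_TOKENS,
# COSINE_THRESHOLD, RELATED_CHUNK_NUMBER), then to the DEFAULT_* constants.
```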

env.example CHANGED

@@ -62,6 +62,8 @@ ENABLE_LLM_CACHE=true
 # MAX_RELATION_TOKENS=10000
 ### control the maximum tokens sent to the LLM (includes entities, relations and chunks)
 # MAX_TOTAL_TOKENS=32000
+### maximum number of related chunks to retrieve from a single entity or relation
+# RELATED_CHUNK_NUMBER=5
 
 ### Reranker configuration (Set ENABLE_RERANK to true if a reranking model is configured)
 ENABLE_RERANK=False
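For context on how these commented-out `.env` entries take effect: the commit routes every such setting through `get_env_value` from `lightrag.utils`. A rough, self-contained sketch of that `(name, default, cast)` contract follows; it mirrors the apparent behavior, not the library's actual implementation:

```python
import os

# Rough stand-in for lightrag.utils.get_env_value: read the variable, cast it,
# and fall back to the coded default when it is unset (or, here, unparsable).
def get_env_value(name, default, cast=str):
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return cast(raw)
    except (TypeError, ValueError):
        return default

os.environ["RELATED_CHUNK_NUMBER"] = "8"
print(get_env_value("RELATED_CHUNK_NUMBER", 5, int))  # -> 8
print(get_env_value("CHUNK_TOP_K", 5, int))           # -> 5 (unset, default used)
```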
lightrag/api/config.py CHANGED

@@ -14,6 +14,11 @@ from lightrag.constants import (
     DEFAULT_TOP_K,
     DEFAULT_CHUNK_TOP_K,
     DEFAULT_HISTORY_TURNS,
+    DEFAULT_MAX_ENTITY_TOKENS,
+    DEFAULT_MAX_RELATION_TOKENS,
+    DEFAULT_MAX_TOTAL_TOKENS,
+    DEFAULT_COSINE_THRESHOLD,
+    DEFAULT_RELATED_CHUNK_NUMBER,
 )
 
 # use the .env that is inside the current folder
@@ -154,33 +159,6 @@ def parse_args() -> argparse.Namespace:
         help="Path to SSL private key file (required if --ssl is enabled)",
     )
 
-    parser.add_argument(
-        "--history-turns",
-        type=int,
-        default=get_env_value("HISTORY_TURNS", DEFAULT_HISTORY_TURNS, int),
-        help="Number of conversation history turns to include (default: from env or 3)",
-    )
-
-    # Search parameters
-    parser.add_argument(
-        "--top-k",
-        type=int,
-        default=get_env_value("TOP_K", DEFAULT_TOP_K, int),
-        help="Number of most similar results to return (default: from env or 60)",
-    )
-    parser.add_argument(
-        "--chunk-top-k",
-        type=int,
-        default=get_env_value("CHUNK_TOP_K", DEFAULT_CHUNK_TOP_K, int),
-        help="Number of text chunks to retrieve initially from vector search and keep after reranking (default: from env or 5)",
-    )
-    parser.add_argument(
-        "--cosine-threshold",
-        type=float,
-        default=get_env_value("COSINE_THRESHOLD", 0.2, float),
-        help="Cosine similarity threshold (default: from env or 0.4)",
-    )
-
     # Ollama model name
     parser.add_argument(
         "--simulated-model-name",
@@ -312,6 +290,26 @@ def parse_args() -> argparse.Namespace:
     args.rerank_binding_host = get_env_value("RERANK_BINDING_HOST", None)
     args.rerank_binding_api_key = get_env_value("RERANK_BINDING_API_KEY", None)
 
+    # Query configuration
+    args.history_turns = get_env_value("HISTORY_TURNS", DEFAULT_HISTORY_TURNS, int)
+    args.top_k = get_env_value("TOP_K", DEFAULT_TOP_K, int)
+    args.chunk_top_k = get_env_value("CHUNK_TOP_K", DEFAULT_CHUNK_TOP_K, int)
+    args.max_entity_tokens = get_env_value(
+        "MAX_ENTITY_TOKENS", DEFAULT_MAX_ENTITY_TOKENS, int
+    )
+    args.max_relation_tokens = get_env_value(
+        "MAX_RELATION_TOKENS", DEFAULT_MAX_RELATION_TOKENS, int
+    )
+    args.max_total_tokens = get_env_value(
+        "MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS, int
+    )
+    args.cosine_threshold = get_env_value(
+        "COSINE_THRESHOLD", DEFAULT_COSINE_THRESHOLD, float
+    )
+    args.related_chunk_number = get_env_value(
+        "RELATED_CHUNK_NUMBER", DEFAULT_RELATED_CHUNK_NUMBER, int
+    )
+
     ollama_server_infos.LIGHTRAG_MODEL = args.simulated_model_name
 
     return args
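Net effect of this file: `--history-turns`, `--top-k`, `--chunk-top-k` and `--cosine-threshold` disappear as CLI flags, and the same names are populated on the parsed namespace straight from the environment. A hedged usage sketch, assuming `parse_args()` runs cleanly with an empty command line:

```python
import os

# The query settings now reach the args namespace via env vars only, so they
# must be exported before parse_args() is called.
os.environ["TOP_K"] = "40"

from lightrag.api.config import parse_args

args = parse_args()
print(args.top_k)                 # 40, from the environment
print(args.related_chunk_number)  # 5, from DEFAULT_RELATED_CHUNK_NUMBER
```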
lightrag/constants.py CHANGED

@@ -20,6 +20,8 @@ DEFAULT_MAX_RELATION_TOKENS = 10000
 DEFAULT_MAX_TOTAL_TOKENS = 32000
 DEFAULT_HISTORY_TURNS = 3
 DEFAULT_ENABLE_RERANK = True
+DEFAULT_COSINE_THRESHOLD = 0.2
+DEFAULT_RELATED_CHUNK_NUMBER = 5
 
 # Separator for graph fields
 GRAPH_FIELD_SEP = "<SEP>"
@@ -28,6 +30,3 @@ GRAPH_FIELD_SEP = "<SEP>"
 DEFAULT_LOG_MAX_BYTES = 10485760  # Default 10MB
 DEFAULT_LOG_BACKUP_COUNT = 5  # Default 5 backups
 DEFAULT_LOG_FILENAME = "lightrag.log"  # Default log filename
-
-# Related Chunk Number for Single Entity or Relation
-DEFAULT_RELATED_CHUNK_NUMBER = 5
lightrag/lightrag.py CHANGED

@@ -24,6 +24,13 @@ from typing import (
 from lightrag.constants import (
     DEFAULT_MAX_GLEANING,
     DEFAULT_FORCE_LLM_SUMMARY_ON_MERGE,
+    DEFAULT_TOP_K,
+    DEFAULT_CHUNK_TOP_K,
+    DEFAULT_MAX_ENTITY_TOKENS,
+    DEFAULT_MAX_RELATION_TOKENS,
+    DEFAULT_MAX_TOTAL_TOKENS,
+    DEFAULT_COSINE_THRESHOLD,
+    DEFAULT_RELATED_CHUNK_NUMBER,
 )
 from lightrag.utils import get_env_value
 
@@ -125,6 +132,42 @@ class LightRAG:
     log_level: int | None = field(default=None)
     log_file_path: str | None = field(default=None)
 
+    # Query parameters
+    # ---
+
+    top_k: int = field(default=get_env_value("TOP_K", DEFAULT_TOP_K, int))
+    """Number of entities/relations to retrieve for each query."""
+
+    chunk_top_k: int = field(
+        default=get_env_value("CHUNK_TOP_K", DEFAULT_CHUNK_TOP_K, int)
+    )
+    """Maximum number of chunks in context."""
+
+    max_entity_tokens: int = field(
+        default=get_env_value("MAX_ENTITY_TOKENS", DEFAULT_MAX_ENTITY_TOKENS, int)
+    )
+    """Maximum number of tokens for entities in context."""
+
+    max_relation_tokens: int = field(
+        default=get_env_value("MAX_RELATION_TOKENS", DEFAULT_MAX_RELATION_TOKENS, int)
+    )
+    """Maximum number of tokens for relations in context."""
+
+    max_total_tokens: int = field(
+        default=get_env_value("MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS, int)
+    )
+    """Maximum total tokens in context (including system prompt, entities, relations and chunks)."""
+
+    cosine_threshold: float = field(
+        default=get_env_value("COSINE_THRESHOLD", DEFAULT_COSINE_THRESHOLD, float)
+    )
+    """Cosine similarity threshold for vector DB retrieval of entities, relations and chunks."""
+
+    related_chunk_number: int = field(
+        default=get_env_value("RELATED_CHUNK_NUMBER", DEFAULT_RELATED_CHUNK_NUMBER, int)
+    )
+    """Number of related chunks to retrieve for a single entity or relation."""
+
     # Entity extraction
     # ---
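Note: the `cosine_threshold` field is typed and cast as `float` above (the original hunk used `int`, which would crash on an env value like `"0.2"`). One property of the `field(default=get_env_value(...))` pattern is also worth flagging: the default is evaluated once, while the class body executes at import time, so changing the environment afterwards does not affect new instances. A self-contained demonstration of that dataclass behavior, using stand-in names rather than LightRAG code:

```python
import os
from dataclasses import dataclass, field

os.environ["TOP_K"] = "10"

@dataclass
class Cfg:
    # The default is computed right here, while the class body runs.
    top_k: int = field(default=int(os.environ.get("TOP_K", "60")))

os.environ["TOP_K"] = "99"  # changed after the class was defined...
print(Cfg().top_k)          # ...still 10: the default froze at definition time
```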
lightrag/operate.py CHANGED

@@ -1908,7 +1908,6 @@ async def _build_query_context(
         ll_keywords,
         knowledge_graph_inst,
         entities_vdb,
-        text_chunks_db,
         query_param,
     )
     original_node_datas = node_datas
@@ -1924,7 +1923,6 @@
         hl_keywords,
         knowledge_graph_inst,
         relationships_vdb,
-        text_chunks_db,
         query_param,
     )
     original_edge_datas = edge_datas
@@ -1935,14 +1933,12 @@
         ll_keywords,
         knowledge_graph_inst,
         entities_vdb,
-        text_chunks_db,
         query_param,
     )
     hl_data = await _get_edge_data(
         hl_keywords,
         knowledge_graph_inst,
         relationships_vdb,
-        text_chunks_db,
         query_param,
     )
 
@@ -1985,23 +1981,17 @@
     max_entity_tokens = getattr(
         query_param,
         "max_entity_tokens",
-        text_chunks_db.global_config.get(
-            "MAX_ENTITY_TOKENS", DEFAULT_MAX_ENTITY_TOKENS
-        ),
+        text_chunks_db.global_config.get("max_entity_tokens", DEFAULT_MAX_ENTITY_TOKENS),
     )
     max_relation_tokens = getattr(
         query_param,
         "max_relation_tokens",
-        text_chunks_db.global_config.get(
-            "MAX_RELATION_TOKENS", DEFAULT_MAX_RELATION_TOKENS
-        ),
+        text_chunks_db.global_config.get("max_relation_tokens", DEFAULT_MAX_RELATION_TOKENS),
     )
     max_total_tokens = getattr(
         query_param,
         "max_total_tokens",
-        text_chunks_db.global_config.get(
-            "MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS
-        ),
+        text_chunks_db.global_config.get("max_total_tokens", DEFAULT_MAX_TOTAL_TOKENS),
     )
 
     # Truncate entities based on complete JSON serialization
@@ -2095,7 +2085,6 @@
             final_edge_datas,
             query_param,
             text_chunks_db,
-            knowledge_graph_inst,
         )
     )
 
@@ -2255,7 +2244,6 @@ async def _get_node_data(
     query: str,
     knowledge_graph_inst: BaseGraphStorage,
    entities_vdb: BaseVectorStorage,
-    text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
 ):
     # get similar entities
@@ -2362,7 +2350,7 @@ async def _find_most_related_text_unit_from_entities(
 
     text_units = [
         split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])[
-            :DEFAULT_RELATED_CHUNK_NUMBER
+            : text_chunks_db.global_config.get("related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER)
         ]
         for dp in node_datas
         if dp["source_id"] is not None
@@ -2519,7 +2507,6 @@ async def _get_edge_data(
     keywords,
     knowledge_graph_inst: BaseGraphStorage,
     relationships_vdb: BaseVectorStorage,
-    text_chunks_db: BaseKVStorage,
     query_param: QueryParam,
 ):
     logger.info(
@@ -2668,13 +2655,12 @@ async def _find_related_text_unit_from_relationships(
     edge_datas: list[dict],
     query_param: QueryParam,
     text_chunks_db: BaseKVStorage,
-    knowledge_graph_inst: BaseGraphStorage,
 ):
     logger.debug(f"Searching text chunks for {len(edge_datas)} relationships")
 
     text_units = [
         split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])[
-            :DEFAULT_RELATED_CHUNK_NUMBER
+            : text_chunks_db.global_config.get("related_chunk_number", DEFAULT_RELATED_CHUNK_NUMBER)
         ]
         for dp in edge_datas
         if dp["source_id"] is not None
@@ -2761,7 +2747,7 @@ async def naive_query(
     max_total_tokens = getattr(
         query_param,
         "max_total_tokens",
-        global_config.get("MAX_TOTAL_TOKENS", DEFAULT_MAX_TOTAL_TOKENS),
+        global_config.get("max_total_tokens", DEFAULT_MAX_TOTAL_TOKENS),
     )
 
     # Calculate conversation history tokens
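The recurring pattern in this file is a three-level precedence for each setting: a per-query override on `QueryParam`, then the instance-level `global_config` (whose keys are now the lowercase `LightRAG` field names rather than the old uppercase env-style keys, presumably because the config dict is derived from the dataclass), then the package constant. A stand-in sketch of that lookup for `max_total_tokens`:

```python
# QueryParam and global_config here are simplified stand-ins, not LightRAG types.
DEFAULT_MAX_TOTAL_TOKENS = 32000             # package constant, last resort

global_config = {"max_total_tokens": 24000}  # derived from the LightRAG instance

class QueryParam:                            # per-query override carrier
    pass                                     # no max_total_tokens set here

query_param = QueryParam()
max_total_tokens = getattr(
    query_param,
    "max_total_tokens",
    global_config.get("max_total_tokens", DEFAULT_MAX_TOTAL_TOKENS),
)
print(max_total_tokens)  # 24000: instance config wins when the query sets nothing
```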