yangdx committed
Commit 344ad52 · Parent: e83d6af

Removed query mode parsing and related tests


- Removed query mode parsing logic (a reconstruction of the deleted helper is sketched below)
- Removed test_generate_query_modes
- Simplified generate endpoint
- Updated test cases list
- Cleaned up unused code
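
For reference, here is a minimal sketch of what the deleted parse_query_mode helper presumably did, inferred from the diff (it returned a (cleaned_query, mode) pair) and from the removed test below, which exercised /local, /global, /naive, /hybrid, and /mix prefixes. The body is an assumption, not the project's actual implementation:

from typing import Optional, Tuple

# Hypothetical reconstruction -- the real helper was removed in this commit.
QUERY_MODES = ["local", "global", "naive", "hybrid", "mix"]

def parse_query_mode(query: str) -> Tuple[str, Optional[str]]:
    """Strip a leading "/<mode> " prefix from the prompt and return
    (cleaned_query, mode); mode is None when no prefix matches."""
    for mode in QUERY_MODES:
        prefix = f"/{mode} "
        if query.startswith(prefix):
            return query[len(prefix):], mode
    return query, None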

lightrag/api/lightrag_server.py CHANGED
@@ -1260,17 +1260,13 @@ def create_app(args):
     async def generate(raw_request: Request, request: OllamaGenerateRequest):
         """Handle generate completion requests"""
         try:
-            # Get the query content
             query = request.prompt
-
-            # Parse the query mode
-            cleaned_query, mode = parse_query_mode(query)
-
+
             # Start timing
             start_time = time.time_ns()

             # Count the input tokens
-            prompt_tokens = estimate_tokens(cleaned_query)
+            prompt_tokens = estimate_tokens(query)

             # Query via llm_model_func directly
             if request.system:
@@ -1280,7 +1276,7 @@ def create_app(args):
                 from fastapi.responses import StreamingResponse

                 response = await rag.llm_model_func(
-                    cleaned_query,
+                    query,
                     stream=True,
                     **rag.llm_model_kwargs
                 )
@@ -1378,7 +1374,7 @@ def create_app(args):
             else:
                 first_chunk_time = time.time_ns()
                 response_text = await rag.llm_model_func(
-                    cleaned_query,
+                    query,
                     stream=False,
                     **rag.llm_model_kwargs
                 )
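
With the parsing step gone, token counting now runs on the raw prompt. estimate_tokens itself is not shown in this diff; a minimal sketch of a typical implementation, assuming the tiktoken package (an assumption for illustration, not necessarily what LightRAG ships):

import tiktoken

def estimate_tokens(text: str) -> int:
    # Assumption: cl100k_base is a reasonable proxy for the model's tokenizer.
    encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(text))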
test_lightrag_ollama_chat.py CHANGED
@@ -679,32 +679,6 @@ def test_generate_concurrent() -> None:
         print(f"\nRequest {i} result:")
         print_json_response(result)

-def test_generate_query_modes() -> None:
-    """Test different query mode prefixes for generate endpoint"""
-    url = get_base_url("generate")
-    modes = ["local", "global", "naive", "hybrid", "mix"]
-
-    for mode in modes:
-        if OutputControl.is_verbose():
-            print(f"\n=== Testing /{mode} mode for generate ===")
-        data = create_generate_request_data(
-            f"/{mode} {CONFIG['test_cases']['generate']['query']}",
-            stream=False
-        )
-
-        # Send request
-        response = make_request(url, data)
-        response_json = response.json()
-
-        # Print response content
-        print_json_response(
-            {
-                "model": response_json["model"],
-                "response": response_json["response"],
-                "done": response_json["done"]
-            }
-        )
-
 def get_test_cases() -> Dict[str, Callable]:
     """Get all available test cases
     Returns:
@@ -719,7 +693,6 @@ def get_test_cases() -> Dict[str, Callable]:
         "non_stream_generate": test_non_stream_generate,
         "stream_generate": test_stream_generate,
         "generate_with_system": test_generate_with_system,
-        "generate_modes": test_generate_query_modes,
         "generate_errors": test_generate_error_handling,
         "generate_stats": test_generate_performance_stats,
         "generate_concurrent": test_generate_concurrent
@@ -821,7 +794,6 @@ if __name__ == "__main__":
        run_test(test_non_stream_generate, "Non-streaming Generate Test")
        run_test(test_stream_generate, "Streaming Generate Test")
        run_test(test_generate_with_system, "Generate with System Prompt Test")
-       run_test(test_generate_query_modes, "Generate Query Mode Test")
        run_test(test_generate_error_handling, "Generate Error Handling Test")
        run_test(test_generate_performance_stats, "Generate Performance Stats Test")
        run_test(test_generate_concurrent, "Generate Concurrent Test")
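
After this commit a client sends a plain prompt; "/local"-style prefixes are no longer interpreted by the generate endpoint. A hedged usage sketch with requests, assuming an Ollama-compatible /api/generate route on localhost:9621 (path, port, and model name are assumptions, not confirmed by this diff; the model, response, and done keys match those printed by the removed test):

import requests

url = "http://localhost:9621/api/generate"  # assumed host/port and route

payload = {
    "model": "lightrag:latest",        # assumed model name
    "prompt": "What is LightRAG?",     # plain prompt, no /mode prefix
    "stream": False,
}

resp = requests.post(url, json=payload, timeout=60)
resp.raise_for_status()
body = resp.json()
print(body["model"], body["done"])
print(body["response"])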