yangdx commited on
Commit
99a3d9e
·
1 Parent(s): 528d6fd

Translate comment to English

Browse files
lightrag/api/lightrag_ollama.py CHANGED
@@ -27,15 +27,15 @@ from dotenv import load_dotenv
27
  load_dotenv()
28
 
29
  def estimate_tokens(text: str) -> int:
30
- """估算文本的token数量
31
- 中文每字约1.5个token
32
- 英文每字约0.25个token
33
  """
34
- # 使用正则表达式分别匹配中文字符和非中文字符
35
  chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
36
  non_chinese_chars = len(re.findall(r'[^\u4e00-\u9fff]', text))
37
 
38
- # 计算估算的token数量
39
  tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25
40
 
41
  return int(tokens)
@@ -241,7 +241,7 @@ class DocumentManager:
241
  class SearchMode(str, Enum):
242
  naive = "naive"
243
  local = "local"
244
- global_ = "global" # 使用 global_ 因为 global Python 保留关键字,但枚举值会转换为字符串 "global"
245
  hybrid = "hybrid"
246
  mix = "mix"
247
 
@@ -254,7 +254,7 @@ class OllamaMessage(BaseModel):
254
  class OllamaChatRequest(BaseModel):
255
  model: str = LIGHTRAG_MODEL
256
  messages: List[OllamaMessage]
257
- stream: bool = True # 默认为流式模式
258
  options: Optional[Dict[str, Any]] = None
259
 
260
  class OllamaChatResponse(BaseModel):
@@ -490,11 +490,11 @@ def create_app(args):
490
  ),
491
  )
492
 
493
- # 如果响应是字符串(比如命中缓存),直接返回
494
  if isinstance(response, str):
495
  return QueryResponse(response=response)
496
 
497
- # 如果是异步生成器,根据stream参数决定是否流式返回
498
  if request.stream:
499
  result = ""
500
  async for chunk in response:
@@ -511,7 +511,7 @@ def create_app(args):
511
  @app.post("/query/stream", dependencies=[Depends(optional_api_key)])
512
  async def query_text_stream(request: QueryRequest):
513
  try:
514
- response = await rag.aquery( # 使用 aquery 而不是 query,并添加 await
515
  request.query,
516
  param=QueryParam(
517
  mode=request.mode,
@@ -691,7 +691,7 @@ def create_app(args):
691
 
692
  for prefix, mode in mode_map.items():
693
  if query.startswith(prefix):
694
- # 移除前缀后,清理开头的额外空格
695
  cleaned_query = query[len(prefix):].lstrip()
696
  return cleaned_query, mode
697
 
@@ -699,17 +699,14 @@ def create_app(args):
699
 
700
  @app.post("/api/chat")
701
  async def chat(raw_request: Request, request: OllamaChatRequest):
702
- # # 打印原始请求数据
703
- # body = await raw_request.body()
704
- # logging.info(f"收到 /api/chat 原始请求: {body.decode('utf-8')}")
705
  """Handle chat completion requests"""
706
  try:
707
- # 获取所有消息内容
708
  messages = request.messages
709
  if not messages:
710
  raise HTTPException(status_code=400, detail="No messages provided")
711
 
712
- # 获取最后一条消息作为查询
713
  query = messages[-1].content
714
 
715
  # 解析查询模式
@@ -723,7 +720,7 @@ def create_app(args):
723
 
724
  # 调用RAG进行查询
725
  query_param = QueryParam(
726
- mode=mode, # 使用解析出的模式,如果没有前缀则为默认的 hybrid
727
  stream=request.stream,
728
  only_need_context=False
729
  )
@@ -731,7 +728,7 @@ def create_app(args):
731
  if request.stream:
732
  from fastapi.responses import StreamingResponse
733
 
734
- response = await rag.aquery( # 需要 await 来获取异步生成器
735
  cleaned_query,
736
  param=query_param
737
  )
@@ -742,9 +739,9 @@ def create_app(args):
742
  last_chunk_time = None
743
  total_response = ""
744
 
745
- # 确保 response 是异步生成器
746
  if isinstance(response, str):
747
- # 如果是字符串,分两次发送
748
  first_chunk_time = time.time_ns()
749
  last_chunk_time = first_chunk_time
750
  total_response = response
 
27
  load_dotenv()
28
 
29
  def estimate_tokens(text: str) -> int:
30
+ """Estimate the number of tokens in text
31
+ Chinese characters: approximately 1.5 tokens per character
32
+ English characters: approximately 0.25 tokens per character
33
  """
34
+ # Use regex to match Chinese and non-Chinese characters separately
35
  chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
36
  non_chinese_chars = len(re.findall(r'[^\u4e00-\u9fff]', text))
37
 
38
+ # Calculate estimated token count
39
  tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25
40
 
41
  return int(tokens)
 
241
  class SearchMode(str, Enum):
242
  naive = "naive"
243
  local = "local"
244
+ global_ = "global" # Using global_ because global is a Python reserved keyword, but enum value will be converted to string "global"
245
  hybrid = "hybrid"
246
  mix = "mix"
247
 
 
254
  class OllamaChatRequest(BaseModel):
255
  model: str = LIGHTRAG_MODEL
256
  messages: List[OllamaMessage]
257
+ stream: bool = True # Default to streaming mode
258
  options: Optional[Dict[str, Any]] = None
259
 
260
  class OllamaChatResponse(BaseModel):
 
490
  ),
491
  )
492
 
493
+ # If response is a string (e.g. cache hit), return directly
494
  if isinstance(response, str):
495
  return QueryResponse(response=response)
496
 
497
+ # If it's an async generator, decide whether to stream based on stream parameter
498
  if request.stream:
499
  result = ""
500
  async for chunk in response:
 
511
  @app.post("/query/stream", dependencies=[Depends(optional_api_key)])
512
  async def query_text_stream(request: QueryRequest):
513
  try:
514
+ response = await rag.aquery( # Use aquery instead of query, and add await
515
  request.query,
516
  param=QueryParam(
517
  mode=request.mode,
 
691
 
692
  for prefix, mode in mode_map.items():
693
  if query.startswith(prefix):
694
+ # After removing prefix an leading spaces
695
  cleaned_query = query[len(prefix):].lstrip()
696
  return cleaned_query, mode
697
 
 
699
 
700
  @app.post("/api/chat")
701
  async def chat(raw_request: Request, request: OllamaChatRequest):
 
 
 
702
  """Handle chat completion requests"""
703
  try:
704
+ # Get all messages
705
  messages = request.messages
706
  if not messages:
707
  raise HTTPException(status_code=400, detail="No messages provided")
708
 
709
+ # Get the last message as query
710
  query = messages[-1].content
711
 
712
  # 解析查询模式
 
720
 
721
  # 调用RAG进行查询
722
  query_param = QueryParam(
723
+ mode=mode,
724
  stream=request.stream,
725
  only_need_context=False
726
  )
 
728
  if request.stream:
729
  from fastapi.responses import StreamingResponse
730
 
731
+ response = await rag.aquery( # Need await to get async generator
732
  cleaned_query,
733
  param=query_param
734
  )
 
739
  last_chunk_time = None
740
  total_response = ""
741
 
742
+ # Ensure response is an async generator
743
  if isinstance(response, str):
744
+ # If it's a string, send in two parts
745
  first_chunk_time = time.time_ns()
746
  last_chunk_time = first_chunk_time
747
  total_response = response
test_lightrag_ollama_chat.py CHANGED
@@ -1,12 +1,12 @@
1
  """
2
- LightRAG Ollama 兼容接口测试脚本
3
 
4
- 这个脚本测试 LightRAG Ollama 兼容接口,包括:
5
- 1. 基本功能测试(流式和非流式响应)
6
- 2. 查询模式测试(localglobalnaivehybrid
7
- 3. 错误处理测试(包括流式和非流式场景)
8
 
9
- 所有响应都使用 JSON Lines 格式,符合 Ollama API 规范。
10
  """
11
 
12
  import requests
@@ -24,20 +24,10 @@ class OutputControl:
24
 
25
  @classmethod
26
  def set_verbose(cls, verbose: bool) -> None:
27
- """设置输出详细程度
28
-
29
- Args:
30
- verbose: True 为详细模式,False 为静默模式
31
- """
32
  cls._verbose = verbose
33
 
34
  @classmethod
35
  def is_verbose(cls) -> bool:
36
- """获取当前输出模式
37
-
38
- Returns:
39
- 当前是否为详细模式
40
- """
41
  return cls._verbose
42
 
43
  @dataclass
@@ -48,9 +38,8 @@ class TestResult:
48
  duration: float
49
  error: Optional[str] = None
50
  timestamp: str = ""
51
-
52
  def __post_init__(self):
53
- """初始化后设置时间戳"""
54
  if not self.timestamp:
55
  self.timestamp = datetime.now().isoformat()
56
 
@@ -59,14 +48,13 @@ class TestStats:
59
  def __init__(self):
60
  self.results: List[TestResult] = []
61
  self.start_time = datetime.now()
62
-
63
  def add_result(self, result: TestResult):
64
- """添加测试结果"""
65
  self.results.append(result)
66
-
67
  def export_results(self, path: str = "test_results.json"):
68
  """导出测试结果到 JSON 文件
69
-
70
  Args:
71
  path: 输出文件路径
72
  """
@@ -81,25 +69,24 @@ class TestStats:
81
  "total_duration": sum(r.duration for r in self.results)
82
  }
83
  }
84
-
85
  with open(path, "w", encoding="utf-8") as f:
86
  json.dump(results_data, f, ensure_ascii=False, indent=2)
87
  print(f"\n测试结果已保存到: {path}")
88
-
89
  def print_summary(self):
90
- """打印测试统计摘要"""
91
  total = len(self.results)
92
  passed = sum(1 for r in self.results if r.success)
93
  failed = total - passed
94
  duration = sum(r.duration for r in self.results)
95
-
96
  print("\n=== 测试结果摘要 ===")
97
  print(f"开始时间: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
98
  print(f"总用时: {duration:.2f}秒")
99
  print(f"总计: {total} 个测试")
100
  print(f"通过: {passed} 个")
101
  print(f"失败: {failed} 个")
102
-
103
  if failed > 0:
104
  print("\n失败的测试:")
105
  for result in self.results:
@@ -125,15 +112,15 @@ DEFAULT_CONFIG = {
125
 
126
  def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response:
127
  """发送 HTTP 请求,支持重试机制
128
-
129
  Args:
130
  url: 请求 URL
131
  data: 请求数据
132
  stream: 是否使用流式响应
133
-
134
  Returns:
135
- requests.Response 对象
136
-
137
  Raises:
138
  requests.exceptions.RequestException: 请求失败且重试次数用完
139
  """
@@ -141,7 +128,7 @@ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> reques
141
  max_retries = server_config["max_retries"]
142
  retry_delay = server_config["retry_delay"]
143
  timeout = server_config["timeout"]
144
-
145
  for attempt in range(max_retries):
146
  try:
147
  response = requests.post(
@@ -159,10 +146,10 @@ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> reques
159
 
160
  def load_config() -> Dict[str, Any]:
161
  """加载配置文件
162
-
163
  首先尝试从当前目录的 config.json 加载,
164
  如果不存在则使用默认配置
165
-
166
  Returns:
167
  配置字典
168
  """
@@ -174,7 +161,7 @@ def load_config() -> Dict[str, Any]:
174
 
175
  def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None:
176
  """格式化打印 JSON 响应数据
177
-
178
  Args:
179
  data: 要打印的数据字典
180
  title: 打印的标题
@@ -199,12 +186,12 @@ def create_request_data(
199
  model: str = None
200
  ) -> Dict[str, Any]:
201
  """创建基本的请求数据
202
-
203
  Args:
204
  content: 用户消息内容
205
  stream: 是否使用流式响应
206
  model: 模型名称
207
-
208
  Returns:
209
  包含完整请求数据的字典
210
  """
@@ -224,7 +211,7 @@ STATS = TestStats()
224
 
225
  def run_test(func: Callable, name: str) -> None:
226
  """运行测试并记录结果
227
-
228
  Args:
229
  func: 测试函数
230
  name: 测试名称
@@ -246,21 +233,21 @@ def test_non_stream_chat():
246
  CONFIG["test_cases"]["basic"]["query"],
247
  stream=False
248
  )
249
-
250
  # 发送请求
251
  response = make_request(url, data)
252
-
253
  # 打印响应
254
  if OutputControl.is_verbose():
255
  print("\n=== 非流式调用响应 ===")
256
  response_json = response.json()
257
-
258
  # 打印响应内容
259
  print_json_response({
260
  "model": response_json["model"],
261
  "message": response_json["message"]
262
  }, "响应内容")
263
-
264
  # # 打印性能统计
265
  # print_json_response({
266
  # "total_duration": response_json["total_duration"],
@@ -273,7 +260,7 @@ def test_non_stream_chat():
273
 
274
  def test_stream_chat():
275
  """测试流式调用 /api/chat 接口
276
-
277
  使用 JSON Lines 格式处理流式响应,每行是一个完整的 JSON 对象。
278
  响应格式:
279
  {
@@ -286,7 +273,7 @@ def test_stream_chat():
286
  },
287
  "done": false
288
  }
289
-
290
  最后一条消息会包含性能统计信息,done 为 true。
291
  """
292
  url = get_base_url()
@@ -294,10 +281,10 @@ def test_stream_chat():
294
  CONFIG["test_cases"]["basic"]["query"],
295
  stream=True
296
  )
297
-
298
  # 发送请求并获取流式响应
299
  response = make_request(url, data, stream=True)
300
-
301
  if OutputControl.is_verbose():
302
  print("\n=== 流式调用响应 ===")
303
  output_buffer = []
@@ -321,24 +308,24 @@ def test_stream_chat():
321
  print("Error decoding JSON from response line")
322
  finally:
323
  response.close() # 确保关闭响应连接
324
-
325
  # 打印一个换行
326
  print()
327
 
328
  def test_query_modes():
329
  """测试不同的查询模式前缀
330
-
331
  支持的查询模式:
332
  - /local: 本地检索模式,只在相关度高的文档中搜索
333
  - /global: 全局检索模式,在所有文档中搜索
334
  - /naive: 朴素模式,不使用任何优化策略
335
  - /hybrid: 混合模式(默认),结合多种策略
336
-
337
  每个模式都会返回相同格式的响应,但检索策略不同。
338
  """
339
  url = get_base_url()
340
  modes = ["local", "global", "naive", "hybrid", "mix"] # 支持的查询模式
341
-
342
  for mode in modes:
343
  if OutputControl.is_verbose():
344
  print(f"\n=== 测试 /{mode} 模式 ===")
@@ -346,11 +333,11 @@ def test_query_modes():
346
  f"/{mode} {CONFIG['test_cases']['basic']['query']}",
347
  stream=False
348
  )
349
-
350
  # 发送请求
351
  response = make_request(url, data)
352
  response_json = response.json()
353
-
354
  # 打印响应内容
355
  print_json_response({
356
  "model": response_json["model"],
@@ -359,13 +346,13 @@ def test_query_modes():
359
 
360
  def create_error_test_data(error_type: str) -> Dict[str, Any]:
361
  """创建用于错误测试的请求数据
362
-
363
  Args:
364
  error_type: 错误类型,支持:
365
  - empty_messages: 空消息列表
366
  - invalid_role: 无效的角色字段
367
  - missing_content: 缺少内容字段
368
-
369
  Returns:
370
  包含错误数据的请求字典
371
  """
@@ -399,19 +386,19 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]:
399
 
400
  def test_stream_error_handling():
401
  """测试流式响应的错误处理
402
-
403
  测试场景:
404
  1. 空消息列表
405
  2. 消息格式错误(缺少必需字段)
406
-
407
  错误响应会立即返回,不会建立流式连接。
408
  状态码应该是 4xx,并返回详细的错误信息。
409
  """
410
  url = get_base_url()
411
-
412
  if OutputControl.is_verbose():
413
  print("\n=== 测试流式响应错误处理 ===")
414
-
415
  # 测试空消息列表
416
  if OutputControl.is_verbose():
417
  print("\n--- 测试空消息列表(流式)---")
@@ -421,7 +408,7 @@ def test_stream_error_handling():
421
  if response.status_code != 200:
422
  print_json_response(response.json(), "错误信息")
423
  response.close()
424
-
425
  # 测试无效角色字段
426
  if OutputControl.is_verbose():
427
  print("\n--- 测试无效角色字段(流式)---")
@@ -444,23 +431,23 @@ def test_stream_error_handling():
444
 
445
  def test_error_handling():
446
  """测试非流式响应的错误处理
447
-
448
  测试场景:
449
  1. 空消息列表
450
  2. 消息格式错误(缺少必需字段)
451
-
452
  错误响应格式:
453
  {
454
  "detail": "错误描述"
455
  }
456
-
457
  所有错误都应该返回合适的 HTTP 状态码和清晰的错误信息。
458
  """
459
  url = get_base_url()
460
-
461
  if OutputControl.is_verbose():
462
  print("\n=== 测试错误处理 ===")
463
-
464
  # 测试空消息列表
465
  if OutputControl.is_verbose():
466
  print("\n--- 测试空消息列表 ---")
@@ -469,7 +456,7 @@ def test_error_handling():
469
  response = make_request(url, data)
470
  print(f"状态码: {response.status_code}")
471
  print_json_response(response.json(), "错误信息")
472
-
473
  # 测试无效角色字段
474
  if OutputControl.is_verbose():
475
  print("\n--- 测试无效���色字段 ---")
@@ -490,7 +477,7 @@ def test_error_handling():
490
 
491
  def get_test_cases() -> Dict[str, Callable]:
492
  """获取所有可用的测试用例
493
-
494
  Returns:
495
  测试名称到测试函数的映射字典
496
  """
@@ -564,21 +551,21 @@ def parse_args() -> argparse.Namespace:
564
 
565
  if __name__ == "__main__":
566
  args = parse_args()
567
-
568
  # 设置输出模式
569
  OutputControl.set_verbose(not args.quiet)
570
-
571
  # 如果指定了查询内容,更新配置
572
  if args.ask:
573
  CONFIG["test_cases"]["basic"]["query"] = args.ask
574
-
575
  # 如果指定了创建配置文件
576
  if args.init_config:
577
  create_default_config()
578
  exit(0)
579
-
580
  test_cases = get_test_cases()
581
-
582
  try:
583
  if "all" in args.tests:
584
  # 运行所有测试
@@ -586,11 +573,11 @@ if __name__ == "__main__":
586
  print("\n【基本功能测试】")
587
  run_test(test_non_stream_chat, "非流式调用测试")
588
  run_test(test_stream_chat, "流式调用测试")
589
-
590
  if OutputControl.is_verbose():
591
  print("\n【查询模式测试】")
592
  run_test(test_query_modes, "查询模式测试")
593
-
594
  if OutputControl.is_verbose():
595
  print("\n【错误处理测试】")
596
  run_test(test_error_handling, "错误处理测试")
 
1
  """
2
+ LightRAG Ollama Compatibility Interface Test Script
3
 
4
+ This script tests the LightRAG's Ollama compatibility interface, including:
5
+ 1. Basic functionality tests (streaming and non-streaming responses)
6
+ 2. Query mode tests (local, global, naive, hybrid)
7
+ 3. Error handling tests (including streaming and non-streaming scenarios)
8
 
9
+ All responses use the JSON Lines format, complying with the Ollama API specification.
10
  """
11
 
12
  import requests
 
24
 
25
  @classmethod
26
  def set_verbose(cls, verbose: bool) -> None:
 
 
 
 
 
27
  cls._verbose = verbose
28
 
29
  @classmethod
30
  def is_verbose(cls) -> bool:
 
 
 
 
 
31
  return cls._verbose
32
 
33
  @dataclass
 
38
  duration: float
39
  error: Optional[str] = None
40
  timestamp: str = ""
41
+
42
  def __post_init__(self):
 
43
  if not self.timestamp:
44
  self.timestamp = datetime.now().isoformat()
45
 
 
48
  def __init__(self):
49
  self.results: List[TestResult] = []
50
  self.start_time = datetime.now()
51
+
52
  def add_result(self, result: TestResult):
 
53
  self.results.append(result)
54
+
55
  def export_results(self, path: str = "test_results.json"):
56
  """导出测试结果到 JSON 文件
57
+
58
  Args:
59
  path: 输出文件路径
60
  """
 
69
  "total_duration": sum(r.duration for r in self.results)
70
  }
71
  }
72
+
73
  with open(path, "w", encoding="utf-8") as f:
74
  json.dump(results_data, f, ensure_ascii=False, indent=2)
75
  print(f"\n测试结果已保存到: {path}")
76
+
77
  def print_summary(self):
 
78
  total = len(self.results)
79
  passed = sum(1 for r in self.results if r.success)
80
  failed = total - passed
81
  duration = sum(r.duration for r in self.results)
82
+
83
  print("\n=== 测试结果摘要 ===")
84
  print(f"开始时间: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
85
  print(f"总用时: {duration:.2f}秒")
86
  print(f"总计: {total} 个测试")
87
  print(f"通过: {passed} 个")
88
  print(f"失败: {failed} 个")
89
+
90
  if failed > 0:
91
  print("\n失败的测试:")
92
  for result in self.results:
 
112
 
113
  def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response:
114
  """发送 HTTP 请求,支持重试机制
115
+
116
  Args:
117
  url: 请求 URL
118
  data: 请求数据
119
  stream: 是否使用流式响应
120
+
121
  Returns:
122
+ requests.Response: 对象
123
+
124
  Raises:
125
  requests.exceptions.RequestException: 请求失败且重试次数用完
126
  """
 
128
  max_retries = server_config["max_retries"]
129
  retry_delay = server_config["retry_delay"]
130
  timeout = server_config["timeout"]
131
+
132
  for attempt in range(max_retries):
133
  try:
134
  response = requests.post(
 
146
 
147
  def load_config() -> Dict[str, Any]:
148
  """加载配置文件
149
+
150
  首先尝试从当前目录的 config.json 加载,
151
  如果不存在则使用默认配置
152
+
153
  Returns:
154
  配置字典
155
  """
 
161
 
162
  def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None:
163
  """格式化打印 JSON 响应数据
164
+
165
  Args:
166
  data: 要打印的数据字典
167
  title: 打印的标题
 
186
  model: str = None
187
  ) -> Dict[str, Any]:
188
  """创建基本的请求数据
189
+
190
  Args:
191
  content: 用户消息内容
192
  stream: 是否使用流式响应
193
  model: 模型名称
194
+
195
  Returns:
196
  包含完整请求数据的字典
197
  """
 
211
 
212
  def run_test(func: Callable, name: str) -> None:
213
  """运行测试并记录结果
214
+
215
  Args:
216
  func: 测试函数
217
  name: 测试名称
 
233
  CONFIG["test_cases"]["basic"]["query"],
234
  stream=False
235
  )
236
+
237
  # 发送请求
238
  response = make_request(url, data)
239
+
240
  # 打印响应
241
  if OutputControl.is_verbose():
242
  print("\n=== 非流式调用响应 ===")
243
  response_json = response.json()
244
+
245
  # 打印响应内容
246
  print_json_response({
247
  "model": response_json["model"],
248
  "message": response_json["message"]
249
  }, "响应内容")
250
+
251
  # # 打印性能统计
252
  # print_json_response({
253
  # "total_duration": response_json["total_duration"],
 
260
 
261
  def test_stream_chat():
262
  """测试流式调用 /api/chat 接口
263
+
264
  使用 JSON Lines 格式处理流式响应,每行是一个完整的 JSON 对象。
265
  响应格式:
266
  {
 
273
  },
274
  "done": false
275
  }
276
+
277
  最后一条消息会包含性能统计信息,done 为 true。
278
  """
279
  url = get_base_url()
 
281
  CONFIG["test_cases"]["basic"]["query"],
282
  stream=True
283
  )
284
+
285
  # 发送请求并获取流式响应
286
  response = make_request(url, data, stream=True)
287
+
288
  if OutputControl.is_verbose():
289
  print("\n=== 流式调用响应 ===")
290
  output_buffer = []
 
308
  print("Error decoding JSON from response line")
309
  finally:
310
  response.close() # 确保关闭响应连接
311
+
312
  # 打印一个换行
313
  print()
314
 
315
  def test_query_modes():
316
  """测试不同的查询模式前缀
317
+
318
  支持的查询模式:
319
  - /local: 本地检索模式,只在相关度高的文档中搜索
320
  - /global: 全局检索模式,在所有文档中搜索
321
  - /naive: 朴素模式,不使用任何优化策略
322
  - /hybrid: 混合模式(默认),结合多种策略
323
+
324
  每个模式都会返回相同格式的响应,但检索策略不同。
325
  """
326
  url = get_base_url()
327
  modes = ["local", "global", "naive", "hybrid", "mix"] # 支持的查询模式
328
+
329
  for mode in modes:
330
  if OutputControl.is_verbose():
331
  print(f"\n=== 测试 /{mode} 模式 ===")
 
333
  f"/{mode} {CONFIG['test_cases']['basic']['query']}",
334
  stream=False
335
  )
336
+
337
  # 发送请求
338
  response = make_request(url, data)
339
  response_json = response.json()
340
+
341
  # 打印响应内容
342
  print_json_response({
343
  "model": response_json["model"],
 
346
 
347
  def create_error_test_data(error_type: str) -> Dict[str, Any]:
348
  """创建用于错误测试的请求数据
349
+
350
  Args:
351
  error_type: 错误类型,支持:
352
  - empty_messages: 空消息列表
353
  - invalid_role: 无效的角色字段
354
  - missing_content: 缺少内容字段
355
+
356
  Returns:
357
  包含错误数据的请求字典
358
  """
 
386
 
387
  def test_stream_error_handling():
388
  """测试流式响应的错误处理
389
+
390
  测试场景:
391
  1. 空消息列表
392
  2. 消息格式错误(缺少必需字段)
393
+
394
  错误响应会立即返回,不会建立流式连接。
395
  状态码应该是 4xx,并返回详细的错误信息。
396
  """
397
  url = get_base_url()
398
+
399
  if OutputControl.is_verbose():
400
  print("\n=== 测试流式响应错误处理 ===")
401
+
402
  # 测试空消息列表
403
  if OutputControl.is_verbose():
404
  print("\n--- 测试空消息列表(流式)---")
 
408
  if response.status_code != 200:
409
  print_json_response(response.json(), "错误信息")
410
  response.close()
411
+
412
  # 测试无效角色字段
413
  if OutputControl.is_verbose():
414
  print("\n--- 测试无效角色字段(流式)---")
 
431
 
432
  def test_error_handling():
433
  """测试非流式响应的错误处理
434
+
435
  测试场景:
436
  1. 空消息列表
437
  2. 消息格式错误(缺少必需字段)
438
+
439
  错误响应格式:
440
  {
441
  "detail": "错误描述"
442
  }
443
+
444
  所有错误都应该返回合适的 HTTP 状态码和清晰的错误信息。
445
  """
446
  url = get_base_url()
447
+
448
  if OutputControl.is_verbose():
449
  print("\n=== 测试错误处理 ===")
450
+
451
  # 测试空消息列表
452
  if OutputControl.is_verbose():
453
  print("\n--- 测试空消息列表 ---")
 
456
  response = make_request(url, data)
457
  print(f"状态码: {response.status_code}")
458
  print_json_response(response.json(), "错误信息")
459
+
460
  # 测试无效角色字段
461
  if OutputControl.is_verbose():
462
  print("\n--- 测试无效���色字段 ---")
 
477
 
478
  def get_test_cases() -> Dict[str, Callable]:
479
  """获取所有可用的测试用例
480
+
481
  Returns:
482
  测试名称到测试函数的映射字典
483
  """
 
551
 
552
  if __name__ == "__main__":
553
  args = parse_args()
554
+
555
  # 设置输出模式
556
  OutputControl.set_verbose(not args.quiet)
557
+
558
  # 如果指定了查询内容,更新配置
559
  if args.ask:
560
  CONFIG["test_cases"]["basic"]["query"] = args.ask
561
+
562
  # 如果指定了创建配置文件
563
  if args.init_config:
564
  create_default_config()
565
  exit(0)
566
+
567
  test_cases = get_test_cases()
568
+
569
  try:
570
  if "all" in args.tests:
571
  # 运行所有测试
 
573
  print("\n【基本功能测试】")
574
  run_test(test_non_stream_chat, "非流式调用测试")
575
  run_test(test_stream_chat, "流式调用测试")
576
+
577
  if OutputControl.is_verbose():
578
  print("\n【查询模式测试】")
579
  run_test(test_query_modes, "查询模式测试")
580
+
581
  if OutputControl.is_verbose():
582
  print("\n【错误处理测试】")
583
  run_test(test_error_handling, "错误处理测试")