yangdx commited on
Commit
9f80fd7
·
1 Parent(s): 49a6af5

完善测试用例

Browse files
Files changed (2) hide show
  1. .gitignore +1 -0
  2. test_lightrag_ollama_chat.py +560 -54
.gitignore CHANGED
@@ -21,3 +21,4 @@ rag_storage
21
  venv/
22
  examples/input/
23
  examples/output/
 
 
21
  venv/
22
  examples/input/
23
  examples/output/
24
+ test_results.json
test_lightrag_ollama_chat.py CHANGED
@@ -1,96 +1,602 @@
 
 
 
 
 
 
 
 
 
 
 
1
  import requests
2
  import json
3
- import sseclient
 
 
 
 
 
 
4
 
5
- def test_non_stream_chat():
6
- """测试非流式调用 /api/chat 接口"""
7
- url = "http://localhost:9621/api/chat"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- # 构造请求数据
10
- data = {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  "model": "lightrag:latest",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  "messages": [
13
  {
14
  "role": "user",
15
- "content": "孙悟空"
16
  }
17
  ],
18
- "stream": False
19
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  # 发送请求
22
- response = requests.post(url, json=data)
23
 
24
  # 打印响应
25
- print("\n=== 非流式调用响应 ===")
 
26
  response_json = response.json()
27
 
28
- # 打印消息内容
29
- print("=== 响应内容 ===")
30
- print(json.dumps({
31
  "model": response_json["model"],
32
  "message": response_json["message"]
33
- }, ensure_ascii=False, indent=2))
34
 
35
  # 打印性能统计
36
- print("\n=== 性能统计 ===")
37
- stats = {
38
  "total_duration": response_json["total_duration"],
39
  "load_duration": response_json["load_duration"],
40
  "prompt_eval_count": response_json["prompt_eval_count"],
41
  "prompt_eval_duration": response_json["prompt_eval_duration"],
42
  "eval_count": response_json["eval_count"],
43
  "eval_duration": response_json["eval_duration"]
44
- }
45
- print(json.dumps(stats, ensure_ascii=False, indent=2))
46
 
47
  def test_stream_chat():
48
- """测试流式调用 /api/chat 接口"""
49
- url = "http://localhost:9621/api/chat"
50
 
51
- # 构造请求数据
52
- data = {
 
53
  "model": "lightrag:latest",
54
- "messages": [
55
- {
56
- "role": "user",
57
- "content": "孙悟空有什么法力,性格特征是什么"
58
- }
59
- ],
60
- "stream": True
61
  }
62
 
63
- # 发送请求并获取 SSE 流
64
- response = requests.post(url, json=data, stream=True)
65
- client = sseclient.SSEClient(response)
 
 
 
 
 
 
 
66
 
67
- print("\n=== 流式调用响应 ===")
 
68
  output_buffer = []
69
  try:
70
- for event in client.events():
71
- try:
72
- data = json.loads(event.data)
73
- if data.get("done", True): # 如果是完成标记
74
- if "total_duration" in data: # 最终的性能统计消息
75
- print("\n=== 性能统计 ===")
76
- print(json.dumps(data, ensure_ascii=False, indent=2))
77
- break
78
- else: # 正常的内容消息
79
- message = data.get("message", {})
80
- content = message.get("content", "")
81
- if content: # 只收集非空内容
82
- output_buffer.append(content)
83
- except json.JSONDecodeError:
84
- print("Error decoding JSON from SSE event")
 
 
85
  finally:
86
  response.close() # 确保关闭响应连接
87
 
88
- # 一次性打印所有收集到的内容
89
- print("".join(output_buffer))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  if __name__ == "__main__":
92
- # 先测试非流式调用
93
- test_non_stream_chat()
94
 
95
- # 再测试流式调用
96
- test_stream_chat()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LightRAG Ollama 兼容接口测试脚本
3
+
4
+ 这个脚本测试 LightRAG 的 Ollama 兼容接口,包括:
5
+ 1. 基本功能测试(流式和非流式响应)
6
+ 2. 查询模式测试(local、global、naive、hybrid)
7
+ 3. 错误处理测试(包括流式和非流式场景)
8
+
9
+ 所有响应都使用 JSON Lines 格式,符合 Ollama API 规范。
10
+ """
11
+
12
  import requests
13
  import json
14
+ import argparse
15
+ import os
16
+ import time
17
+ from typing import Dict, Any, Optional, List, Callable, Tuple
18
+ from dataclasses import dataclass, asdict
19
+ from datetime import datetime
20
+ from pathlib import Path
21
 
22
+ class OutputControl:
23
+ """输出控制类,管理测试输出的详细程度"""
24
+ _verbose: bool = False
25
+
26
+ @classmethod
27
+ def set_verbose(cls, verbose: bool) -> None:
28
+ """设置输出详细程度
29
+
30
+ Args:
31
+ verbose: True 为详细模式,False 为静默模式
32
+ """
33
+ cls._verbose = verbose
34
+
35
+ @classmethod
36
+ def is_verbose(cls) -> bool:
37
+ """获取当前输出模式
38
+
39
+ Returns:
40
+ 当前是否为详细模式
41
+ """
42
+ return cls._verbose
43
+
44
+ @dataclass
45
+ class TestResult:
46
+ """测试结果数据类"""
47
+ name: str
48
+ success: bool
49
+ duration: float
50
+ error: Optional[str] = None
51
+ timestamp: str = ""
52
+
53
+ def __post_init__(self):
54
+ """初始化后设置时间戳"""
55
+ if not self.timestamp:
56
+ self.timestamp = datetime.now().isoformat()
57
+
58
+ class TestStats:
59
+ """测试统计信息"""
60
+ def __init__(self):
61
+ self.results: List[TestResult] = []
62
+ self.start_time = datetime.now()
63
+
64
+ def add_result(self, result: TestResult):
65
+ """添加测试结果"""
66
+ self.results.append(result)
67
+
68
+ def export_results(self, path: str = "test_results.json"):
69
+ """导出测试结果到 JSON 文件
70
+
71
+ Args:
72
+ path: 输出文件路径
73
+ """
74
+ results_data = {
75
+ "start_time": self.start_time.isoformat(),
76
+ "end_time": datetime.now().isoformat(),
77
+ "results": [asdict(r) for r in self.results],
78
+ "summary": {
79
+ "total": len(self.results),
80
+ "passed": sum(1 for r in self.results if r.success),
81
+ "failed": sum(1 for r in self.results if not r.success),
82
+ "total_duration": sum(r.duration for r in self.results)
83
+ }
84
+ }
85
+
86
+ with open(path, "w", encoding="utf-8") as f:
87
+ json.dump(results_data, f, ensure_ascii=False, indent=2)
88
+ print(f"\n测试结果已保存到: {path}")
89
 
90
+ def print_summary(self):
91
+ """打印测试统计摘要"""
92
+ total = len(self.results)
93
+ passed = sum(1 for r in self.results if r.success)
94
+ failed = total - passed
95
+ duration = sum(r.duration for r in self.results)
96
+
97
+ print("\n=== 测试结果摘要 ===")
98
+ print(f"开始时间: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
99
+ print(f"总用时: {duration:.2f}秒")
100
+ print(f"总计: {total} 个测试")
101
+ print(f"通过: {passed} 个")
102
+ print(f"失败: {failed} 个")
103
+
104
+ if failed > 0:
105
+ print("\n失败的测试:")
106
+ for result in self.results:
107
+ if not result.success:
108
+ print(f"- {result.name}: {result.error}")
109
+
110
+ # 默认配置
111
+ DEFAULT_CONFIG = {
112
+ "server": {
113
+ "host": "localhost",
114
+ "port": 9621,
115
  "model": "lightrag:latest",
116
+ "timeout": 30, # 请求超时时间(秒)
117
+ "max_retries": 3, # 最大重试次数
118
+ "retry_delay": 1 # 重试间隔(秒)
119
+ },
120
+ "test_cases": {
121
+ "basic": {
122
+ "query": "孙悟空",
123
+ "stream_query": "孙悟空有什么法力,性格特征是什么"
124
+ }
125
+ }
126
+ }
127
+
128
+ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response:
129
+ """发送 HTTP 请求,支持重试机制
130
+
131
+ Args:
132
+ url: 请求 URL
133
+ data: 请求数据
134
+ stream: 是否使用流式响应
135
+
136
+ Returns:
137
+ requests.Response 对象
138
+
139
+ Raises:
140
+ requests.exceptions.RequestException: 请求失败且重试次数用完
141
+ """
142
+ server_config = CONFIG["server"]
143
+ max_retries = server_config["max_retries"]
144
+ retry_delay = server_config["retry_delay"]
145
+ timeout = server_config["timeout"]
146
+
147
+ for attempt in range(max_retries):
148
+ try:
149
+ response = requests.post(
150
+ url,
151
+ json=data,
152
+ stream=stream,
153
+ timeout=timeout
154
+ )
155
+ return response
156
+ except requests.exceptions.RequestException as e:
157
+ if attempt == max_retries - 1: # 最后一次重试
158
+ raise
159
+ print(f"\n请求失败,{retry_delay}秒后重试: {str(e)}")
160
+ time.sleep(retry_delay)
161
+
162
+ def load_config() -> Dict[str, Any]:
163
+ """加载配置文件
164
+
165
+ 首先尝试从当前目录的 config.json 加载,
166
+ 如果不存在则使用默认配置
167
+
168
+ Returns:
169
+ 配置字典
170
+ """
171
+ config_path = Path("config.json")
172
+ if config_path.exists():
173
+ with open(config_path, "r", encoding="utf-8") as f:
174
+ return json.load(f)
175
+ return DEFAULT_CONFIG
176
+
177
+ def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None:
178
+ """格式化打印 JSON 响应数据
179
+
180
+ Args:
181
+ data: 要打印的数据字典
182
+ title: 打印的标题
183
+ indent: JSON 缩进空格数
184
+ """
185
+ if OutputControl.is_verbose():
186
+ if title:
187
+ print(f"\n=== {title} ===")
188
+ print(json.dumps(data, ensure_ascii=False, indent=indent))
189
+
190
+ # 全局配置
191
+ CONFIG = load_config()
192
+
193
+ def get_base_url() -> str:
194
+ """返回基础 URL"""
195
+ server = CONFIG["server"]
196
+ return f"http://{server['host']}:{server['port']}/api/chat"
197
+
198
+ def create_request_data(
199
+ content: str,
200
+ stream: bool = False,
201
+ model: str = None
202
+ ) -> Dict[str, Any]:
203
+ """创建基本的请求数据
204
+
205
+ Args:
206
+ content: 用户消息内容
207
+ stream: 是否使用流式响应
208
+ model: 模型名称
209
+
210
+ Returns:
211
+ 包含完整请求数据的字典
212
+ """
213
+ return {
214
+ "model": model or CONFIG["server"]["model"],
215
  "messages": [
216
  {
217
  "role": "user",
218
+ "content": content
219
  }
220
  ],
221
+ "stream": stream
222
  }
223
+
224
+ # 全局测试统计
225
+ STATS = TestStats()
226
+
227
+ def run_test(func: Callable, name: str) -> None:
228
+ """运行测试并记录结果
229
+
230
+ Args:
231
+ func: 测试函数
232
+ name: 测试名称
233
+ """
234
+ start_time = time.time()
235
+ try:
236
+ func()
237
+ duration = time.time() - start_time
238
+ STATS.add_result(TestResult(name, True, duration))
239
+ except Exception as e:
240
+ duration = time.time() - start_time
241
+ STATS.add_result(TestResult(name, False, duration, str(e)))
242
+ raise
243
+
244
+ def test_non_stream_chat():
245
+ """测试非流式调用 /api/chat 接口"""
246
+ url = get_base_url()
247
+ data = create_request_data(
248
+ CONFIG["test_cases"]["basic"]["query"],
249
+ stream=False
250
+ )
251
 
252
  # 发送请求
253
+ response = make_request(url, data)
254
 
255
  # 打印响应
256
+ if OutputControl.is_verbose():
257
+ print("\n=== 非流式调用响应 ===")
258
  response_json = response.json()
259
 
260
+ # 打印响应内容
261
+ print_json_response({
 
262
  "model": response_json["model"],
263
  "message": response_json["message"]
264
+ }, "响应内容")
265
 
266
  # 打印性能统计
267
+ print_json_response({
 
268
  "total_duration": response_json["total_duration"],
269
  "load_duration": response_json["load_duration"],
270
  "prompt_eval_count": response_json["prompt_eval_count"],
271
  "prompt_eval_duration": response_json["prompt_eval_duration"],
272
  "eval_count": response_json["eval_count"],
273
  "eval_duration": response_json["eval_duration"]
274
+ }, "性能统计")
 
275
 
276
  def test_stream_chat():
277
+ """测试流式调用 /api/chat 接口
 
278
 
279
+ 使用 JSON Lines 格式处理流式响应,每行是一个完整的 JSON 对象。
280
+ 响应格式:
281
+ {
282
  "model": "lightrag:latest",
283
+ "created_at": "2024-01-15T00:00:00Z",
284
+ "message": {
285
+ "role": "assistant",
286
+ "content": "部分响应内容",
287
+ "images": null
288
+ },
289
+ "done": false
290
  }
291
 
292
+ 最后一条消息会包含性能统计信息,done true。
293
+ """
294
+ url = get_base_url()
295
+ data = create_request_data(
296
+ CONFIG["test_cases"]["basic"]["stream_query"],
297
+ stream=True
298
+ )
299
+
300
+ # 发送请求并获取流式响应
301
+ response = make_request(url, data, stream=True)
302
 
303
+ if OutputControl.is_verbose():
304
+ print("\n=== 流式调用响应 ===")
305
  output_buffer = []
306
  try:
307
+ for line in response.iter_lines():
308
+ if line: # 跳过空行
309
+ try:
310
+ # 解码并解析 JSON
311
+ data = json.loads(line.decode('utf-8'))
312
+ if data.get("done", True): # 如果是完成标记
313
+ if "total_duration" in data: # 最终的性能统计消息
314
+ print_json_response(data, "性能统计")
315
+ break
316
+ else: # 正常的内容消息
317
+ message = data.get("message", {})
318
+ content = message.get("content", "")
319
+ if content: # 只收集非空内容
320
+ output_buffer.append(content)
321
+ print(content, end="", flush=True) # 实时打印内容
322
+ except json.JSONDecodeError:
323
+ print("Error decoding JSON from response line")
324
  finally:
325
  response.close() # 确保关闭响应连接
326
 
327
+ # 打印一个换行
328
+ print()
329
+
330
+ def test_query_modes():
331
+ """测试不同的查询模式前缀
332
+
333
+ 支持的查询模式:
334
+ - /local: 本地检索模式,只在相关度高的文档中搜索
335
+ - /global: 全局检索模式,在所有文档中搜索
336
+ - /naive: 朴素模式,不使用任何优化策略
337
+ - /hybrid: 混合模式(默认),结合多种策略
338
+
339
+ 每个模式都会返回相同格式的响应,但检索策略不同。
340
+ """
341
+ url = get_base_url()
342
+ modes = ["local", "global", "naive", "hybrid"] # 支持的查询模式
343
+
344
+ for mode in modes:
345
+ if OutputControl.is_verbose():
346
+ print(f"\n=== 测试 /{mode} 模式 ===")
347
+ data = create_request_data(
348
+ f"/{mode} 孙悟空的特点",
349
+ stream=False
350
+ )
351
+
352
+ # 发送请求
353
+ response = make_request(url, data)
354
+ response_json = response.json()
355
+
356
+ # 打印响应内容
357
+ print_json_response({
358
+ "model": response_json["model"],
359
+ "message": response_json["message"]
360
+ })
361
+
362
+ def create_error_test_data(error_type: str) -> Dict[str, Any]:
363
+ """创建用于错误测试的请求数据
364
+
365
+ Args:
366
+ error_type: 错误类型,支持:
367
+ - empty_messages: 空消息列表
368
+ - invalid_role: 无效的角色字段
369
+ - missing_content: 缺少内容字段
370
+
371
+ Returns:
372
+ 包含错误数据的请求字典
373
+ """
374
+ error_data = {
375
+ "empty_messages": {
376
+ "model": "lightrag:latest",
377
+ "messages": [],
378
+ "stream": True
379
+ },
380
+ "invalid_role": {
381
+ "model": "lightrag:latest",
382
+ "messages": [
383
+ {
384
+ "invalid_role": "user",
385
+ "content": "测试消息"
386
+ }
387
+ ],
388
+ "stream": True
389
+ },
390
+ "missing_content": {
391
+ "model": "lightrag:latest",
392
+ "messages": [
393
+ {
394
+ "role": "user"
395
+ }
396
+ ],
397
+ "stream": True
398
+ }
399
+ }
400
+ return error_data.get(error_type, error_data["empty_messages"])
401
+
402
+ def test_stream_error_handling():
403
+ """测试流式响应的错误处理
404
+
405
+ 测试场景:
406
+ 1. 空消息列表
407
+ 2. 消息格式错误(缺少必需字段)
408
+
409
+ 错误响应会立即返回,不会建立流式连接。
410
+ 状态码应该是 4xx,并返回详细的错误信息。
411
+ """
412
+ url = get_base_url()
413
+
414
+ if OutputControl.is_verbose():
415
+ print("\n=== 测试流式响应错误处理 ===")
416
+
417
+ # 测试空消息列表
418
+ if OutputControl.is_verbose():
419
+ print("\n--- 测试空消息列表(流式)---")
420
+ data = create_error_test_data("empty_messages")
421
+ response = make_request(url, data, stream=True)
422
+ print(f"状态码: {response.status_code}")
423
+ if response.status_code != 200:
424
+ print_json_response(response.json(), "错误信息")
425
+ response.close()
426
+
427
+ # 测试无效角色字段
428
+ if OutputControl.is_verbose():
429
+ print("\n--- 测试无效角色字段(流式)---")
430
+ data = create_error_test_data("invalid_role")
431
+ response = make_request(url, data, stream=True)
432
+ print(f"状态码: {response.status_code}")
433
+ if response.status_code != 200:
434
+ print_json_response(response.json(), "错误信息")
435
+ response.close()
436
+
437
+ # 测试缺少内容字段
438
+ if OutputControl.is_verbose():
439
+ print("\n--- 测试缺少内容字段(流式)---")
440
+ data = create_error_test_data("missing_content")
441
+ response = make_request(url, data, stream=True)
442
+ print(f"状态码: {response.status_code}")
443
+ if response.status_code != 200:
444
+ print_json_response(response.json(), "错误信息")
445
+ response.close()
446
+
447
+ def test_error_handling():
448
+ """测试非流式响应的错误处理
449
+
450
+ 测试场景:
451
+ 1. 空消息列表
452
+ 2. 消息格式错误(缺少必需字段)
453
+
454
+ 错误响应格式:
455
+ {
456
+ "detail": "错误描述"
457
+ }
458
+
459
+ 所有错误都应该返回合适的 HTTP 状态码和清晰的错误信息。
460
+ """
461
+ url = get_base_url()
462
+
463
+ if OutputControl.is_verbose():
464
+ print("\n=== 测试错误处理 ===")
465
+
466
+ # 测试空消息列表
467
+ if OutputControl.is_verbose():
468
+ print("\n--- 测试空消息列表 ---")
469
+ data = create_error_test_data("empty_messages")
470
+ data["stream"] = False # 修改为非流式模式
471
+ response = make_request(url, data)
472
+ print(f"状态码: {response.status_code}")
473
+ print_json_response(response.json(), "错误信息")
474
+
475
+ # 测试无效角色字段
476
+ if OutputControl.is_verbose():
477
+ print("\n--- 测试无效角色字段 ---")
478
+ data = create_error_test_data("invalid_role")
479
+ data["stream"] = False # 修改为非流式模式
480
+ response = make_request(url, data)
481
+ print(f"状态码: {response.status_code}")
482
+ print_json_response(response.json(), "错误信息")
483
+
484
+ # 测试缺少内容字段
485
+ if OutputControl.is_verbose():
486
+ print("\n--- 测试缺少内容字段 ---")
487
+ data = create_error_test_data("missing_content")
488
+ data["stream"] = False # 修改为非流式模式
489
+ response = make_request(url, data)
490
+ print(f"状态码: {response.status_code}")
491
+ print_json_response(response.json(), "错误信息")
492
+
493
+ def get_test_cases() -> Dict[str, Callable]:
494
+ """获取所有可用的测试用例
495
+
496
+ Returns:
497
+ 测试名称到测试函数的映射字典
498
+ """
499
+ return {
500
+ "non_stream": test_non_stream_chat,
501
+ "stream": test_stream_chat,
502
+ "modes": test_query_modes,
503
+ "errors": test_error_handling,
504
+ "stream_errors": test_stream_error_handling
505
+ }
506
+
507
+ def create_default_config():
508
+ """创建默认配置文件"""
509
+ config_path = Path("config.json")
510
+ if not config_path.exists():
511
+ with open(config_path, "w", encoding="utf-8") as f:
512
+ json.dump(DEFAULT_CONFIG, f, ensure_ascii=False, indent=2)
513
+ print(f"已创建默认配置文件: {config_path}")
514
+
515
+ def parse_args() -> argparse.Namespace:
516
+ """解析命令行参数"""
517
+ parser = argparse.ArgumentParser(
518
+ description="LightRAG Ollama 兼容接口测试",
519
+ formatter_class=argparse.RawDescriptionHelpFormatter,
520
+ epilog="""
521
+ 配置文件 (config.json):
522
+ {
523
+ "server": {
524
+ "host": "localhost", # 服务器地址
525
+ "port": 9621, # 服务器端口
526
+ "model": "lightrag:latest" # 默认模型名称
527
+ },
528
+ "test_cases": {
529
+ "basic": {
530
+ "query": "测试查询", # 基本查询文本
531
+ "stream_query": "流式查询" # 流式查询文本
532
+ }
533
+ }
534
+ }
535
+ """
536
+ )
537
+ parser.add_argument(
538
+ "--tests",
539
+ nargs="+",
540
+ choices=list(get_test_cases().keys()) + ["all"],
541
+ default=["all"],
542
+ help="要运行的测试用例,可选: %(choices)s。使用 all 运行所有测试"
543
+ )
544
+ parser.add_argument(
545
+ "--init-config",
546
+ action="store_true",
547
+ help="创建默认配置文件"
548
+ )
549
+ parser.add_argument(
550
+ "--output",
551
+ type=str,
552
+ default="test_results.json",
553
+ help="测试结果输出文件路径"
554
+ )
555
+ parser.add_argument(
556
+ "-q", "--quiet",
557
+ action="store_true",
558
+ help="静默模式,只显示测试结果摘要"
559
+ )
560
+ return parser.parse_args()
561
 
562
  if __name__ == "__main__":
563
+ args = parse_args()
 
564
 
565
+ # 设置输出模式
566
+ OutputControl.set_verbose(not args.quiet)
567
+
568
+ # 如果指定了创建配置文件
569
+ if args.init_config:
570
+ create_default_config()
571
+ exit(0)
572
+
573
+ test_cases = get_test_cases()
574
+
575
+ try:
576
+ if "all" in args.tests:
577
+ # 运行所有测试
578
+ if OutputControl.is_verbose():
579
+ print("\n【基本功能测试】")
580
+ run_test(test_non_stream_chat, "非流式调用测试")
581
+ run_test(test_stream_chat, "流式调用测试")
582
+
583
+ if OutputControl.is_verbose():
584
+ print("\n【查询模式测试】")
585
+ run_test(test_query_modes, "查询模式测试")
586
+
587
+ if OutputControl.is_verbose():
588
+ print("\n【错误处理测试】")
589
+ run_test(test_error_handling, "错误处理测试")
590
+ run_test(test_stream_error_handling, "流式错误处理测试")
591
+ else:
592
+ # 运行指定的测试
593
+ for test_name in args.tests:
594
+ if OutputControl.is_verbose():
595
+ print(f"\n【运行测试: {test_name}】")
596
+ run_test(test_cases[test_name], test_name)
597
+ except Exception as e:
598
+ print(f"\n发生错误: {str(e)}")
599
+ finally:
600
+ # 打印并导出测试统计
601
+ STATS.print_summary()
602
+ STATS.export_results(args.output)