yangdx
commited on
Commit
·
9f80fd7
1
Parent(s):
49a6af5
完善测试用例
Browse files- .gitignore +1 -0
- test_lightrag_ollama_chat.py +560 -54
.gitignore
CHANGED
@@ -21,3 +21,4 @@ rag_storage
|
|
21 |
venv/
|
22 |
examples/input/
|
23 |
examples/output/
|
|
|
|
21 |
venv/
|
22 |
examples/input/
|
23 |
examples/output/
|
24 |
+
test_results.json
|
test_lightrag_ollama_chat.py
CHANGED
@@ -1,96 +1,602 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import requests
|
2 |
import json
|
3 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
-
|
6 |
-
"""
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
"model": "lightrag:latest",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
"messages": [
|
13 |
{
|
14 |
"role": "user",
|
15 |
-
"content":
|
16 |
}
|
17 |
],
|
18 |
-
"stream":
|
19 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
# 发送请求
|
22 |
-
response =
|
23 |
|
24 |
# 打印响应
|
25 |
-
|
|
|
26 |
response_json = response.json()
|
27 |
|
28 |
-
#
|
29 |
-
|
30 |
-
print(json.dumps({
|
31 |
"model": response_json["model"],
|
32 |
"message": response_json["message"]
|
33 |
-
},
|
34 |
|
35 |
# 打印性能统计
|
36 |
-
|
37 |
-
stats = {
|
38 |
"total_duration": response_json["total_duration"],
|
39 |
"load_duration": response_json["load_duration"],
|
40 |
"prompt_eval_count": response_json["prompt_eval_count"],
|
41 |
"prompt_eval_duration": response_json["prompt_eval_duration"],
|
42 |
"eval_count": response_json["eval_count"],
|
43 |
"eval_duration": response_json["eval_duration"]
|
44 |
-
}
|
45 |
-
print(json.dumps(stats, ensure_ascii=False, indent=2))
|
46 |
|
47 |
def test_stream_chat():
|
48 |
-
"""测试流式调用 /api/chat 接口
|
49 |
-
url = "http://localhost:9621/api/chat"
|
50 |
|
51 |
-
|
52 |
-
|
|
|
53 |
"model": "lightrag:latest",
|
54 |
-
"
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
"
|
61 |
}
|
62 |
|
63 |
-
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
|
67 |
-
|
|
|
68 |
output_buffer = []
|
69 |
try:
|
70 |
-
for
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
|
|
|
|
85 |
finally:
|
86 |
response.close() # 确保关闭响应连接
|
87 |
|
88 |
-
#
|
89 |
-
print(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
if __name__ == "__main__":
|
92 |
-
|
93 |
-
test_non_stream_chat()
|
94 |
|
95 |
-
#
|
96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
LightRAG Ollama 兼容接口测试脚本
|
3 |
+
|
4 |
+
这个脚本测试 LightRAG 的 Ollama 兼容接口,包括:
|
5 |
+
1. 基本功能测试(流式和非流式响应)
|
6 |
+
2. 查询模式测试(local、global、naive、hybrid)
|
7 |
+
3. 错误处理测试(包括流式和非流式场景)
|
8 |
+
|
9 |
+
所有响应都使用 JSON Lines 格式,符合 Ollama API 规范。
|
10 |
+
"""
|
11 |
+
|
12 |
import requests
|
13 |
import json
|
14 |
+
import argparse
|
15 |
+
import os
|
16 |
+
import time
|
17 |
+
from typing import Dict, Any, Optional, List, Callable, Tuple
|
18 |
+
from dataclasses import dataclass, asdict
|
19 |
+
from datetime import datetime
|
20 |
+
from pathlib import Path
|
21 |
|
22 |
+
class OutputControl:
|
23 |
+
"""输出控制类,管理测试输出的详细程度"""
|
24 |
+
_verbose: bool = False
|
25 |
+
|
26 |
+
@classmethod
|
27 |
+
def set_verbose(cls, verbose: bool) -> None:
|
28 |
+
"""设置输出详细程度
|
29 |
+
|
30 |
+
Args:
|
31 |
+
verbose: True 为详细模式,False 为静默模式
|
32 |
+
"""
|
33 |
+
cls._verbose = verbose
|
34 |
+
|
35 |
+
@classmethod
|
36 |
+
def is_verbose(cls) -> bool:
|
37 |
+
"""获取当前输出模式
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
当前是否为详细模式
|
41 |
+
"""
|
42 |
+
return cls._verbose
|
43 |
+
|
44 |
+
@dataclass
|
45 |
+
class TestResult:
|
46 |
+
"""测试结果数据类"""
|
47 |
+
name: str
|
48 |
+
success: bool
|
49 |
+
duration: float
|
50 |
+
error: Optional[str] = None
|
51 |
+
timestamp: str = ""
|
52 |
+
|
53 |
+
def __post_init__(self):
|
54 |
+
"""初始化后设置时间戳"""
|
55 |
+
if not self.timestamp:
|
56 |
+
self.timestamp = datetime.now().isoformat()
|
57 |
+
|
58 |
+
class TestStats:
|
59 |
+
"""测试统计信息"""
|
60 |
+
def __init__(self):
|
61 |
+
self.results: List[TestResult] = []
|
62 |
+
self.start_time = datetime.now()
|
63 |
+
|
64 |
+
def add_result(self, result: TestResult):
|
65 |
+
"""添加测试结果"""
|
66 |
+
self.results.append(result)
|
67 |
+
|
68 |
+
def export_results(self, path: str = "test_results.json"):
|
69 |
+
"""导出测试结果到 JSON 文件
|
70 |
+
|
71 |
+
Args:
|
72 |
+
path: 输出文件路径
|
73 |
+
"""
|
74 |
+
results_data = {
|
75 |
+
"start_time": self.start_time.isoformat(),
|
76 |
+
"end_time": datetime.now().isoformat(),
|
77 |
+
"results": [asdict(r) for r in self.results],
|
78 |
+
"summary": {
|
79 |
+
"total": len(self.results),
|
80 |
+
"passed": sum(1 for r in self.results if r.success),
|
81 |
+
"failed": sum(1 for r in self.results if not r.success),
|
82 |
+
"total_duration": sum(r.duration for r in self.results)
|
83 |
+
}
|
84 |
+
}
|
85 |
+
|
86 |
+
with open(path, "w", encoding="utf-8") as f:
|
87 |
+
json.dump(results_data, f, ensure_ascii=False, indent=2)
|
88 |
+
print(f"\n测试结果已保存到: {path}")
|
89 |
|
90 |
+
def print_summary(self):
|
91 |
+
"""打印测试统计摘要"""
|
92 |
+
total = len(self.results)
|
93 |
+
passed = sum(1 for r in self.results if r.success)
|
94 |
+
failed = total - passed
|
95 |
+
duration = sum(r.duration for r in self.results)
|
96 |
+
|
97 |
+
print("\n=== 测试结果摘要 ===")
|
98 |
+
print(f"开始时间: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
|
99 |
+
print(f"总用时: {duration:.2f}秒")
|
100 |
+
print(f"总计: {total} 个测试")
|
101 |
+
print(f"通过: {passed} 个")
|
102 |
+
print(f"失败: {failed} 个")
|
103 |
+
|
104 |
+
if failed > 0:
|
105 |
+
print("\n失败的测试:")
|
106 |
+
for result in self.results:
|
107 |
+
if not result.success:
|
108 |
+
print(f"- {result.name}: {result.error}")
|
109 |
+
|
110 |
+
# 默认配置
|
111 |
+
DEFAULT_CONFIG = {
|
112 |
+
"server": {
|
113 |
+
"host": "localhost",
|
114 |
+
"port": 9621,
|
115 |
"model": "lightrag:latest",
|
116 |
+
"timeout": 30, # 请求超时时间(秒)
|
117 |
+
"max_retries": 3, # 最大重试次数
|
118 |
+
"retry_delay": 1 # 重试间隔(秒)
|
119 |
+
},
|
120 |
+
"test_cases": {
|
121 |
+
"basic": {
|
122 |
+
"query": "孙悟空",
|
123 |
+
"stream_query": "孙悟空有什么法力,性格特征是什么"
|
124 |
+
}
|
125 |
+
}
|
126 |
+
}
|
127 |
+
|
128 |
+
def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response:
|
129 |
+
"""发送 HTTP 请求,支持重试机制
|
130 |
+
|
131 |
+
Args:
|
132 |
+
url: 请求 URL
|
133 |
+
data: 请求数据
|
134 |
+
stream: 是否使用流式响应
|
135 |
+
|
136 |
+
Returns:
|
137 |
+
requests.Response 对象
|
138 |
+
|
139 |
+
Raises:
|
140 |
+
requests.exceptions.RequestException: 请求失败且重试次数用完
|
141 |
+
"""
|
142 |
+
server_config = CONFIG["server"]
|
143 |
+
max_retries = server_config["max_retries"]
|
144 |
+
retry_delay = server_config["retry_delay"]
|
145 |
+
timeout = server_config["timeout"]
|
146 |
+
|
147 |
+
for attempt in range(max_retries):
|
148 |
+
try:
|
149 |
+
response = requests.post(
|
150 |
+
url,
|
151 |
+
json=data,
|
152 |
+
stream=stream,
|
153 |
+
timeout=timeout
|
154 |
+
)
|
155 |
+
return response
|
156 |
+
except requests.exceptions.RequestException as e:
|
157 |
+
if attempt == max_retries - 1: # 最后一次重试
|
158 |
+
raise
|
159 |
+
print(f"\n请求失败,{retry_delay}秒后重试: {str(e)}")
|
160 |
+
time.sleep(retry_delay)
|
161 |
+
|
162 |
+
def load_config() -> Dict[str, Any]:
|
163 |
+
"""加载配置文件
|
164 |
+
|
165 |
+
首先尝试从当前目录的 config.json 加载,
|
166 |
+
如果不存在则使用默认配置
|
167 |
+
|
168 |
+
Returns:
|
169 |
+
配置字典
|
170 |
+
"""
|
171 |
+
config_path = Path("config.json")
|
172 |
+
if config_path.exists():
|
173 |
+
with open(config_path, "r", encoding="utf-8") as f:
|
174 |
+
return json.load(f)
|
175 |
+
return DEFAULT_CONFIG
|
176 |
+
|
177 |
+
def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None:
|
178 |
+
"""格式化打印 JSON 响应数据
|
179 |
+
|
180 |
+
Args:
|
181 |
+
data: 要打印的数据字典
|
182 |
+
title: 打印的标题
|
183 |
+
indent: JSON 缩进空格数
|
184 |
+
"""
|
185 |
+
if OutputControl.is_verbose():
|
186 |
+
if title:
|
187 |
+
print(f"\n=== {title} ===")
|
188 |
+
print(json.dumps(data, ensure_ascii=False, indent=indent))
|
189 |
+
|
190 |
+
# 全局配置
|
191 |
+
CONFIG = load_config()
|
192 |
+
|
193 |
+
def get_base_url() -> str:
|
194 |
+
"""返回基础 URL"""
|
195 |
+
server = CONFIG["server"]
|
196 |
+
return f"http://{server['host']}:{server['port']}/api/chat"
|
197 |
+
|
198 |
+
def create_request_data(
|
199 |
+
content: str,
|
200 |
+
stream: bool = False,
|
201 |
+
model: str = None
|
202 |
+
) -> Dict[str, Any]:
|
203 |
+
"""创建基本的请求数据
|
204 |
+
|
205 |
+
Args:
|
206 |
+
content: 用户消息内容
|
207 |
+
stream: 是否使用流式响应
|
208 |
+
model: 模型名称
|
209 |
+
|
210 |
+
Returns:
|
211 |
+
包含完整请求数据的字典
|
212 |
+
"""
|
213 |
+
return {
|
214 |
+
"model": model or CONFIG["server"]["model"],
|
215 |
"messages": [
|
216 |
{
|
217 |
"role": "user",
|
218 |
+
"content": content
|
219 |
}
|
220 |
],
|
221 |
+
"stream": stream
|
222 |
}
|
223 |
+
|
224 |
+
# 全局测试统计
|
225 |
+
STATS = TestStats()
|
226 |
+
|
227 |
+
def run_test(func: Callable, name: str) -> None:
|
228 |
+
"""运行测试并记录结果
|
229 |
+
|
230 |
+
Args:
|
231 |
+
func: 测试函数
|
232 |
+
name: 测试名称
|
233 |
+
"""
|
234 |
+
start_time = time.time()
|
235 |
+
try:
|
236 |
+
func()
|
237 |
+
duration = time.time() - start_time
|
238 |
+
STATS.add_result(TestResult(name, True, duration))
|
239 |
+
except Exception as e:
|
240 |
+
duration = time.time() - start_time
|
241 |
+
STATS.add_result(TestResult(name, False, duration, str(e)))
|
242 |
+
raise
|
243 |
+
|
244 |
+
def test_non_stream_chat():
|
245 |
+
"""测试非流式调用 /api/chat 接口"""
|
246 |
+
url = get_base_url()
|
247 |
+
data = create_request_data(
|
248 |
+
CONFIG["test_cases"]["basic"]["query"],
|
249 |
+
stream=False
|
250 |
+
)
|
251 |
|
252 |
# 发送请求
|
253 |
+
response = make_request(url, data)
|
254 |
|
255 |
# 打印响应
|
256 |
+
if OutputControl.is_verbose():
|
257 |
+
print("\n=== 非流式调用响应 ===")
|
258 |
response_json = response.json()
|
259 |
|
260 |
+
# 打印响应内容
|
261 |
+
print_json_response({
|
|
|
262 |
"model": response_json["model"],
|
263 |
"message": response_json["message"]
|
264 |
+
}, "响应内容")
|
265 |
|
266 |
# 打印性能统计
|
267 |
+
print_json_response({
|
|
|
268 |
"total_duration": response_json["total_duration"],
|
269 |
"load_duration": response_json["load_duration"],
|
270 |
"prompt_eval_count": response_json["prompt_eval_count"],
|
271 |
"prompt_eval_duration": response_json["prompt_eval_duration"],
|
272 |
"eval_count": response_json["eval_count"],
|
273 |
"eval_duration": response_json["eval_duration"]
|
274 |
+
}, "性能统计")
|
|
|
275 |
|
276 |
def test_stream_chat():
|
277 |
+
"""测试流式调用 /api/chat 接口
|
|
|
278 |
|
279 |
+
使用 JSON Lines 格式处理流式响应,每行是一个完整的 JSON 对象。
|
280 |
+
响应格式:
|
281 |
+
{
|
282 |
"model": "lightrag:latest",
|
283 |
+
"created_at": "2024-01-15T00:00:00Z",
|
284 |
+
"message": {
|
285 |
+
"role": "assistant",
|
286 |
+
"content": "部分响应内容",
|
287 |
+
"images": null
|
288 |
+
},
|
289 |
+
"done": false
|
290 |
}
|
291 |
|
292 |
+
最后一条消息会包含性能统计信息,done 为 true。
|
293 |
+
"""
|
294 |
+
url = get_base_url()
|
295 |
+
data = create_request_data(
|
296 |
+
CONFIG["test_cases"]["basic"]["stream_query"],
|
297 |
+
stream=True
|
298 |
+
)
|
299 |
+
|
300 |
+
# 发送请求并获取流式响应
|
301 |
+
response = make_request(url, data, stream=True)
|
302 |
|
303 |
+
if OutputControl.is_verbose():
|
304 |
+
print("\n=== 流式调用响应 ===")
|
305 |
output_buffer = []
|
306 |
try:
|
307 |
+
for line in response.iter_lines():
|
308 |
+
if line: # 跳过空行
|
309 |
+
try:
|
310 |
+
# 解码并解析 JSON
|
311 |
+
data = json.loads(line.decode('utf-8'))
|
312 |
+
if data.get("done", True): # 如果是完成标记
|
313 |
+
if "total_duration" in data: # 最终的性能统计消息
|
314 |
+
print_json_response(data, "性能统计")
|
315 |
+
break
|
316 |
+
else: # 正常的内容消息
|
317 |
+
message = data.get("message", {})
|
318 |
+
content = message.get("content", "")
|
319 |
+
if content: # 只收集非空内容
|
320 |
+
output_buffer.append(content)
|
321 |
+
print(content, end="", flush=True) # 实时打印内容
|
322 |
+
except json.JSONDecodeError:
|
323 |
+
print("Error decoding JSON from response line")
|
324 |
finally:
|
325 |
response.close() # 确保关闭响应连接
|
326 |
|
327 |
+
# 打印一个换行
|
328 |
+
print()
|
329 |
+
|
330 |
+
def test_query_modes():
|
331 |
+
"""测试不同的查询模式前缀
|
332 |
+
|
333 |
+
支持的查询模式:
|
334 |
+
- /local: 本地检索模式,只在相关度高的文档中搜索
|
335 |
+
- /global: 全局检索模式,在所有文档中搜索
|
336 |
+
- /naive: 朴素模式,不使用任何优化策略
|
337 |
+
- /hybrid: 混合模式(默认),结合多种策略
|
338 |
+
|
339 |
+
每个模式都会返回相同格式的响应,但检索策略不同。
|
340 |
+
"""
|
341 |
+
url = get_base_url()
|
342 |
+
modes = ["local", "global", "naive", "hybrid"] # 支持的查询模式
|
343 |
+
|
344 |
+
for mode in modes:
|
345 |
+
if OutputControl.is_verbose():
|
346 |
+
print(f"\n=== 测试 /{mode} 模式 ===")
|
347 |
+
data = create_request_data(
|
348 |
+
f"/{mode} 孙悟空的特点",
|
349 |
+
stream=False
|
350 |
+
)
|
351 |
+
|
352 |
+
# 发送请求
|
353 |
+
response = make_request(url, data)
|
354 |
+
response_json = response.json()
|
355 |
+
|
356 |
+
# 打印响应内容
|
357 |
+
print_json_response({
|
358 |
+
"model": response_json["model"],
|
359 |
+
"message": response_json["message"]
|
360 |
+
})
|
361 |
+
|
362 |
+
def create_error_test_data(error_type: str) -> Dict[str, Any]:
|
363 |
+
"""创建用于错误测试的请求数据
|
364 |
+
|
365 |
+
Args:
|
366 |
+
error_type: 错误类型,支持:
|
367 |
+
- empty_messages: 空消息列表
|
368 |
+
- invalid_role: 无效的角色字段
|
369 |
+
- missing_content: 缺少内容字段
|
370 |
+
|
371 |
+
Returns:
|
372 |
+
包含错误数据的请求字典
|
373 |
+
"""
|
374 |
+
error_data = {
|
375 |
+
"empty_messages": {
|
376 |
+
"model": "lightrag:latest",
|
377 |
+
"messages": [],
|
378 |
+
"stream": True
|
379 |
+
},
|
380 |
+
"invalid_role": {
|
381 |
+
"model": "lightrag:latest",
|
382 |
+
"messages": [
|
383 |
+
{
|
384 |
+
"invalid_role": "user",
|
385 |
+
"content": "测试消息"
|
386 |
+
}
|
387 |
+
],
|
388 |
+
"stream": True
|
389 |
+
},
|
390 |
+
"missing_content": {
|
391 |
+
"model": "lightrag:latest",
|
392 |
+
"messages": [
|
393 |
+
{
|
394 |
+
"role": "user"
|
395 |
+
}
|
396 |
+
],
|
397 |
+
"stream": True
|
398 |
+
}
|
399 |
+
}
|
400 |
+
return error_data.get(error_type, error_data["empty_messages"])
|
401 |
+
|
402 |
+
def test_stream_error_handling():
|
403 |
+
"""测试流式响应的错误处理
|
404 |
+
|
405 |
+
测试场景:
|
406 |
+
1. 空消息列表
|
407 |
+
2. 消息格式错误(缺少必需字段)
|
408 |
+
|
409 |
+
错误响应会立即返回,不会建立流式连接。
|
410 |
+
状态码应该是 4xx,并返回详细的错误信息。
|
411 |
+
"""
|
412 |
+
url = get_base_url()
|
413 |
+
|
414 |
+
if OutputControl.is_verbose():
|
415 |
+
print("\n=== 测试流式响应错误处理 ===")
|
416 |
+
|
417 |
+
# 测试空消息列表
|
418 |
+
if OutputControl.is_verbose():
|
419 |
+
print("\n--- 测试空消息列表(流式)---")
|
420 |
+
data = create_error_test_data("empty_messages")
|
421 |
+
response = make_request(url, data, stream=True)
|
422 |
+
print(f"状态码: {response.status_code}")
|
423 |
+
if response.status_code != 200:
|
424 |
+
print_json_response(response.json(), "错误信息")
|
425 |
+
response.close()
|
426 |
+
|
427 |
+
# 测试无效角色字段
|
428 |
+
if OutputControl.is_verbose():
|
429 |
+
print("\n--- 测试无效角色字段(流式)---")
|
430 |
+
data = create_error_test_data("invalid_role")
|
431 |
+
response = make_request(url, data, stream=True)
|
432 |
+
print(f"状态码: {response.status_code}")
|
433 |
+
if response.status_code != 200:
|
434 |
+
print_json_response(response.json(), "错误信息")
|
435 |
+
response.close()
|
436 |
+
|
437 |
+
# 测试缺少内容字段
|
438 |
+
if OutputControl.is_verbose():
|
439 |
+
print("\n--- 测试缺少内容字段(流式)---")
|
440 |
+
data = create_error_test_data("missing_content")
|
441 |
+
response = make_request(url, data, stream=True)
|
442 |
+
print(f"状态码: {response.status_code}")
|
443 |
+
if response.status_code != 200:
|
444 |
+
print_json_response(response.json(), "错误信息")
|
445 |
+
response.close()
|
446 |
+
|
447 |
+
def test_error_handling():
|
448 |
+
"""测试非流式响应的错误处理
|
449 |
+
|
450 |
+
测试场景:
|
451 |
+
1. 空消息列表
|
452 |
+
2. 消息格式错误(缺少必需字段)
|
453 |
+
|
454 |
+
错误响应格式:
|
455 |
+
{
|
456 |
+
"detail": "错误描述"
|
457 |
+
}
|
458 |
+
|
459 |
+
所有错误都应该返回合适的 HTTP 状态码和清晰的错误信息。
|
460 |
+
"""
|
461 |
+
url = get_base_url()
|
462 |
+
|
463 |
+
if OutputControl.is_verbose():
|
464 |
+
print("\n=== 测试错误处理 ===")
|
465 |
+
|
466 |
+
# 测试空消息列表
|
467 |
+
if OutputControl.is_verbose():
|
468 |
+
print("\n--- 测试空消息列表 ---")
|
469 |
+
data = create_error_test_data("empty_messages")
|
470 |
+
data["stream"] = False # 修改为非流式模式
|
471 |
+
response = make_request(url, data)
|
472 |
+
print(f"状态码: {response.status_code}")
|
473 |
+
print_json_response(response.json(), "错误信息")
|
474 |
+
|
475 |
+
# 测试无效角色字段
|
476 |
+
if OutputControl.is_verbose():
|
477 |
+
print("\n--- 测试无效角色字段 ---")
|
478 |
+
data = create_error_test_data("invalid_role")
|
479 |
+
data["stream"] = False # 修改为非流式模式
|
480 |
+
response = make_request(url, data)
|
481 |
+
print(f"状态码: {response.status_code}")
|
482 |
+
print_json_response(response.json(), "错误信息")
|
483 |
+
|
484 |
+
# 测试缺少内容字段
|
485 |
+
if OutputControl.is_verbose():
|
486 |
+
print("\n--- 测试缺少内容字段 ---")
|
487 |
+
data = create_error_test_data("missing_content")
|
488 |
+
data["stream"] = False # 修改为非流式模式
|
489 |
+
response = make_request(url, data)
|
490 |
+
print(f"状态码: {response.status_code}")
|
491 |
+
print_json_response(response.json(), "错误信息")
|
492 |
+
|
493 |
+
def get_test_cases() -> Dict[str, Callable]:
|
494 |
+
"""获取所有可用的测试用例
|
495 |
+
|
496 |
+
Returns:
|
497 |
+
测试名称到测试函数的映射字典
|
498 |
+
"""
|
499 |
+
return {
|
500 |
+
"non_stream": test_non_stream_chat,
|
501 |
+
"stream": test_stream_chat,
|
502 |
+
"modes": test_query_modes,
|
503 |
+
"errors": test_error_handling,
|
504 |
+
"stream_errors": test_stream_error_handling
|
505 |
+
}
|
506 |
+
|
507 |
+
def create_default_config():
|
508 |
+
"""创建默认配置文件"""
|
509 |
+
config_path = Path("config.json")
|
510 |
+
if not config_path.exists():
|
511 |
+
with open(config_path, "w", encoding="utf-8") as f:
|
512 |
+
json.dump(DEFAULT_CONFIG, f, ensure_ascii=False, indent=2)
|
513 |
+
print(f"已创建默认配置文件: {config_path}")
|
514 |
+
|
515 |
+
def parse_args() -> argparse.Namespace:
|
516 |
+
"""解析命令行参数"""
|
517 |
+
parser = argparse.ArgumentParser(
|
518 |
+
description="LightRAG Ollama 兼容接口测试",
|
519 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
520 |
+
epilog="""
|
521 |
+
配置文件 (config.json):
|
522 |
+
{
|
523 |
+
"server": {
|
524 |
+
"host": "localhost", # 服务器地址
|
525 |
+
"port": 9621, # 服务器端口
|
526 |
+
"model": "lightrag:latest" # 默认模型名称
|
527 |
+
},
|
528 |
+
"test_cases": {
|
529 |
+
"basic": {
|
530 |
+
"query": "测试查询", # 基本查询文本
|
531 |
+
"stream_query": "流式查询" # 流式查询文本
|
532 |
+
}
|
533 |
+
}
|
534 |
+
}
|
535 |
+
"""
|
536 |
+
)
|
537 |
+
parser.add_argument(
|
538 |
+
"--tests",
|
539 |
+
nargs="+",
|
540 |
+
choices=list(get_test_cases().keys()) + ["all"],
|
541 |
+
default=["all"],
|
542 |
+
help="要运行的测试用例,可选: %(choices)s。使用 all 运行所有测试"
|
543 |
+
)
|
544 |
+
parser.add_argument(
|
545 |
+
"--init-config",
|
546 |
+
action="store_true",
|
547 |
+
help="创建默认配置文件"
|
548 |
+
)
|
549 |
+
parser.add_argument(
|
550 |
+
"--output",
|
551 |
+
type=str,
|
552 |
+
default="test_results.json",
|
553 |
+
help="测试结果输出文件路径"
|
554 |
+
)
|
555 |
+
parser.add_argument(
|
556 |
+
"-q", "--quiet",
|
557 |
+
action="store_true",
|
558 |
+
help="静默模式,只显示测试结果摘要"
|
559 |
+
)
|
560 |
+
return parser.parse_args()
|
561 |
|
562 |
if __name__ == "__main__":
|
563 |
+
args = parse_args()
|
|
|
564 |
|
565 |
+
# 设置输出模式
|
566 |
+
OutputControl.set_verbose(not args.quiet)
|
567 |
+
|
568 |
+
# 如果指定了创建配置文件
|
569 |
+
if args.init_config:
|
570 |
+
create_default_config()
|
571 |
+
exit(0)
|
572 |
+
|
573 |
+
test_cases = get_test_cases()
|
574 |
+
|
575 |
+
try:
|
576 |
+
if "all" in args.tests:
|
577 |
+
# 运行所有测试
|
578 |
+
if OutputControl.is_verbose():
|
579 |
+
print("\n【基本功能测试】")
|
580 |
+
run_test(test_non_stream_chat, "非流式调用测试")
|
581 |
+
run_test(test_stream_chat, "流式调用测试")
|
582 |
+
|
583 |
+
if OutputControl.is_verbose():
|
584 |
+
print("\n【查询模式测试】")
|
585 |
+
run_test(test_query_modes, "查询模式测试")
|
586 |
+
|
587 |
+
if OutputControl.is_verbose():
|
588 |
+
print("\n【错误处理测试】")
|
589 |
+
run_test(test_error_handling, "错误处理测试")
|
590 |
+
run_test(test_stream_error_handling, "流式错误处理测试")
|
591 |
+
else:
|
592 |
+
# 运行指定的测试
|
593 |
+
for test_name in args.tests:
|
594 |
+
if OutputControl.is_verbose():
|
595 |
+
print(f"\n【运行测试: {test_name}】")
|
596 |
+
run_test(test_cases[test_name], test_name)
|
597 |
+
except Exception as e:
|
598 |
+
print(f"\n发生错误: {str(e)}")
|
599 |
+
finally:
|
600 |
+
# 打印并导出测试统计
|
601 |
+
STATS.print_summary()
|
602 |
+
STATS.export_results(args.output)
|