yangdx
committed on
Commit
·
99a3d9e
1
Parent(s):
528d6fd
Translate comment to English
Browse files- lightrag/api/lightrag_ollama.py +17 -20
- test_lightrag_ollama_chat.py +60 -73
lightrag/api/lightrag_ollama.py
CHANGED
@@ -27,15 +27,15 @@ from dotenv import load_dotenv
|
|
27 |
load_dotenv()
|
28 |
|
29 |
def estimate_tokens(text: str) -> int:
|
30 |
-
"""
|
31 |
-
|
32 |
-
|
33 |
"""
|
34 |
-
#
|
35 |
chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
|
36 |
non_chinese_chars = len(re.findall(r'[^\u4e00-\u9fff]', text))
|
37 |
|
38 |
-
#
|
39 |
tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25
|
40 |
|
41 |
return int(tokens)
|
@@ -241,7 +241,7 @@ class DocumentManager:
|
|
241 |
class SearchMode(str, Enum):
|
242 |
naive = "naive"
|
243 |
local = "local"
|
244 |
-
global_ = "global" #
|
245 |
hybrid = "hybrid"
|
246 |
mix = "mix"
|
247 |
|
@@ -254,7 +254,7 @@ class OllamaMessage(BaseModel):
|
|
254 |
class OllamaChatRequest(BaseModel):
|
255 |
model: str = LIGHTRAG_MODEL
|
256 |
messages: List[OllamaMessage]
|
257 |
-
stream: bool = True #
|
258 |
options: Optional[Dict[str, Any]] = None
|
259 |
|
260 |
class OllamaChatResponse(BaseModel):
|
@@ -490,11 +490,11 @@ def create_app(args):
|
|
490 |
),
|
491 |
)
|
492 |
|
493 |
-
#
|
494 |
if isinstance(response, str):
|
495 |
return QueryResponse(response=response)
|
496 |
|
497 |
-
#
|
498 |
if request.stream:
|
499 |
result = ""
|
500 |
async for chunk in response:
|
@@ -511,7 +511,7 @@ def create_app(args):
|
|
511 |
@app.post("/query/stream", dependencies=[Depends(optional_api_key)])
|
512 |
async def query_text_stream(request: QueryRequest):
|
513 |
try:
|
514 |
-
response = await rag.aquery( #
|
515 |
request.query,
|
516 |
param=QueryParam(
|
517 |
mode=request.mode,
|
@@ -691,7 +691,7 @@ def create_app(args):
|
|
691 |
|
692 |
for prefix, mode in mode_map.items():
|
693 |
if query.startswith(prefix):
|
694 |
-
#
|
695 |
cleaned_query = query[len(prefix):].lstrip()
|
696 |
return cleaned_query, mode
|
697 |
|
@@ -699,17 +699,14 @@ def create_app(args):
|
|
699 |
|
700 |
@app.post("/api/chat")
|
701 |
async def chat(raw_request: Request, request: OllamaChatRequest):
|
702 |
-
# # 打印原始请求数据
|
703 |
-
# body = await raw_request.body()
|
704 |
-
# logging.info(f"收到 /api/chat 原始请求: {body.decode('utf-8')}")
|
705 |
"""Handle chat completion requests"""
|
706 |
try:
|
707 |
-
#
|
708 |
messages = request.messages
|
709 |
if not messages:
|
710 |
raise HTTPException(status_code=400, detail="No messages provided")
|
711 |
|
712 |
-
#
|
713 |
query = messages[-1].content
|
714 |
|
715 |
# 解析查询模式
|
@@ -723,7 +720,7 @@ def create_app(args):
|
|
723 |
|
724 |
# 调用RAG进行查询
|
725 |
query_param = QueryParam(
|
726 |
-
mode=mode,
|
727 |
stream=request.stream,
|
728 |
only_need_context=False
|
729 |
)
|
@@ -731,7 +728,7 @@ def create_app(args):
|
|
731 |
if request.stream:
|
732 |
from fastapi.responses import StreamingResponse
|
733 |
|
734 |
-
response = await rag.aquery( #
|
735 |
cleaned_query,
|
736 |
param=query_param
|
737 |
)
|
@@ -742,9 +739,9 @@ def create_app(args):
|
|
742 |
last_chunk_time = None
|
743 |
total_response = ""
|
744 |
|
745 |
-
#
|
746 |
if isinstance(response, str):
|
747 |
-
#
|
748 |
first_chunk_time = time.time_ns()
|
749 |
last_chunk_time = first_chunk_time
|
750 |
total_response = response
|
|
|
27 |
load_dotenv()
|
28 |
|
29 |
def estimate_tokens(text: str) -> int:
|
30 |
+
"""Estimate the number of tokens in text
|
31 |
+
Chinese characters: approximately 1.5 tokens per character
|
32 |
+
English characters: approximately 0.25 tokens per character
|
33 |
"""
|
34 |
+
# Use regex to match Chinese and non-Chinese characters separately
|
35 |
chinese_chars = len(re.findall(r'[\u4e00-\u9fff]', text))
|
36 |
non_chinese_chars = len(re.findall(r'[^\u4e00-\u9fff]', text))
|
37 |
|
38 |
+
# Calculate estimated token count
|
39 |
tokens = chinese_chars * 1.5 + non_chinese_chars * 0.25
|
40 |
|
41 |
return int(tokens)
|
|
|
241 |
class SearchMode(str, Enum):
|
242 |
naive = "naive"
|
243 |
local = "local"
|
244 |
+
global_ = "global" # Using global_ because global is a Python reserved keyword, but enum value will be converted to string "global"
|
245 |
hybrid = "hybrid"
|
246 |
mix = "mix"
|
247 |
|
|
|
254 |
class OllamaChatRequest(BaseModel):
|
255 |
model: str = LIGHTRAG_MODEL
|
256 |
messages: List[OllamaMessage]
|
257 |
+
stream: bool = True # Default to streaming mode
|
258 |
options: Optional[Dict[str, Any]] = None
|
259 |
|
260 |
class OllamaChatResponse(BaseModel):
|
|
|
490 |
),
|
491 |
)
|
492 |
|
493 |
+
# If response is a string (e.g. cache hit), return directly
|
494 |
if isinstance(response, str):
|
495 |
return QueryResponse(response=response)
|
496 |
|
497 |
+
# If it's an async generator, decide whether to stream based on stream parameter
|
498 |
if request.stream:
|
499 |
result = ""
|
500 |
async for chunk in response:
|
|
|
511 |
@app.post("/query/stream", dependencies=[Depends(optional_api_key)])
|
512 |
async def query_text_stream(request: QueryRequest):
|
513 |
try:
|
514 |
+
response = await rag.aquery( # Use aquery instead of query, and add await
|
515 |
request.query,
|
516 |
param=QueryParam(
|
517 |
mode=request.mode,
|
|
|
691 |
|
692 |
for prefix, mode in mode_map.items():
|
693 |
if query.startswith(prefix):
|
694 |
+
# Strip the mode prefix and any leading whitespace from the query
|
695 |
cleaned_query = query[len(prefix):].lstrip()
|
696 |
return cleaned_query, mode
|
697 |
|
|
|
699 |
|
700 |
@app.post("/api/chat")
|
701 |
async def chat(raw_request: Request, request: OllamaChatRequest):
|
|
|
|
|
|
|
702 |
"""Handle chat completion requests"""
|
703 |
try:
|
704 |
+
# Get all messages
|
705 |
messages = request.messages
|
706 |
if not messages:
|
707 |
raise HTTPException(status_code=400, detail="No messages provided")
|
708 |
|
709 |
+
# Get the last message as query
|
710 |
query = messages[-1].content
|
711 |
|
712 |
# 解析查询模式
|
|
|
720 |
|
721 |
# 调用RAG进行查询
|
722 |
query_param = QueryParam(
|
723 |
+
mode=mode,
|
724 |
stream=request.stream,
|
725 |
only_need_context=False
|
726 |
)
|
|
|
728 |
if request.stream:
|
729 |
from fastapi.responses import StreamingResponse
|
730 |
|
731 |
+
response = await rag.aquery( # Need await to get async generator
|
732 |
cleaned_query,
|
733 |
param=query_param
|
734 |
)
|
|
|
739 |
last_chunk_time = None
|
740 |
total_response = ""
|
741 |
|
742 |
+
# Ensure response is an async generator
|
743 |
if isinstance(response, str):
|
744 |
+
# If it's a string, send in two parts
|
745 |
first_chunk_time = time.time_ns()
|
746 |
last_chunk_time = first_chunk_time
|
747 |
total_response = response
|
test_lightrag_ollama_chat.py
CHANGED
@@ -1,12 +1,12 @@
|
|
1 |
"""
|
2 |
-
LightRAG Ollama
|
3 |
|
4 |
-
|
5 |
-
1.
|
6 |
-
2.
|
7 |
-
3.
|
8 |
|
9 |
-
|
10 |
"""
|
11 |
|
12 |
import requests
|
@@ -24,20 +24,10 @@ class OutputControl:
|
|
24 |
|
25 |
@classmethod
|
26 |
def set_verbose(cls, verbose: bool) -> None:
|
27 |
-
"""设置输出详细程度
|
28 |
-
|
29 |
-
Args:
|
30 |
-
verbose: True 为详细模式,False 为静默模式
|
31 |
-
"""
|
32 |
cls._verbose = verbose
|
33 |
|
34 |
@classmethod
|
35 |
def is_verbose(cls) -> bool:
|
36 |
-
"""获取当前输出模式
|
37 |
-
|
38 |
-
Returns:
|
39 |
-
当前是否为详细模式
|
40 |
-
"""
|
41 |
return cls._verbose
|
42 |
|
43 |
@dataclass
|
@@ -48,9 +38,8 @@ class TestResult:
|
|
48 |
duration: float
|
49 |
error: Optional[str] = None
|
50 |
timestamp: str = ""
|
51 |
-
|
52 |
def __post_init__(self):
|
53 |
-
"""初始化后设置时间戳"""
|
54 |
if not self.timestamp:
|
55 |
self.timestamp = datetime.now().isoformat()
|
56 |
|
@@ -59,14 +48,13 @@ class TestStats:
|
|
59 |
def __init__(self):
|
60 |
self.results: List[TestResult] = []
|
61 |
self.start_time = datetime.now()
|
62 |
-
|
63 |
def add_result(self, result: TestResult):
|
64 |
-
"""添加测试结果"""
|
65 |
self.results.append(result)
|
66 |
-
|
67 |
def export_results(self, path: str = "test_results.json"):
|
68 |
"""导出测试结果到 JSON 文件
|
69 |
-
|
70 |
Args:
|
71 |
path: 输出文件路径
|
72 |
"""
|
@@ -81,25 +69,24 @@ class TestStats:
|
|
81 |
"total_duration": sum(r.duration for r in self.results)
|
82 |
}
|
83 |
}
|
84 |
-
|
85 |
with open(path, "w", encoding="utf-8") as f:
|
86 |
json.dump(results_data, f, ensure_ascii=False, indent=2)
|
87 |
print(f"\n测试结果已保存到: {path}")
|
88 |
-
|
89 |
def print_summary(self):
|
90 |
-
"""打印测试统计摘要"""
|
91 |
total = len(self.results)
|
92 |
passed = sum(1 for r in self.results if r.success)
|
93 |
failed = total - passed
|
94 |
duration = sum(r.duration for r in self.results)
|
95 |
-
|
96 |
print("\n=== 测试结果摘要 ===")
|
97 |
print(f"开始时间: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
|
98 |
print(f"总用时: {duration:.2f}秒")
|
99 |
print(f"总计: {total} 个测试")
|
100 |
print(f"通过: {passed} 个")
|
101 |
print(f"失败: {failed} 个")
|
102 |
-
|
103 |
if failed > 0:
|
104 |
print("\n失败的测试:")
|
105 |
for result in self.results:
|
@@ -125,15 +112,15 @@ DEFAULT_CONFIG = {
|
|
125 |
|
126 |
def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response:
|
127 |
"""发送 HTTP 请求,支持重试机制
|
128 |
-
|
129 |
Args:
|
130 |
url: 请求 URL
|
131 |
data: 请求数据
|
132 |
stream: 是否使用流式响应
|
133 |
-
|
134 |
Returns:
|
135 |
-
requests.Response 对象
|
136 |
-
|
137 |
Raises:
|
138 |
requests.exceptions.RequestException: 请求失败且重试次数用完
|
139 |
"""
|
@@ -141,7 +128,7 @@ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> reques
|
|
141 |
max_retries = server_config["max_retries"]
|
142 |
retry_delay = server_config["retry_delay"]
|
143 |
timeout = server_config["timeout"]
|
144 |
-
|
145 |
for attempt in range(max_retries):
|
146 |
try:
|
147 |
response = requests.post(
|
@@ -159,10 +146,10 @@ def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> reques
|
|
159 |
|
160 |
def load_config() -> Dict[str, Any]:
|
161 |
"""加载配置文件
|
162 |
-
|
163 |
首先尝试从当前目录的 config.json 加载,
|
164 |
如果不存在则使用默认配置
|
165 |
-
|
166 |
Returns:
|
167 |
配置字典
|
168 |
"""
|
@@ -174,7 +161,7 @@ def load_config() -> Dict[str, Any]:
|
|
174 |
|
175 |
def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None:
|
176 |
"""格式化打印 JSON 响应数据
|
177 |
-
|
178 |
Args:
|
179 |
data: 要打印的数据字典
|
180 |
title: 打印的标题
|
@@ -199,12 +186,12 @@ def create_request_data(
|
|
199 |
model: str = None
|
200 |
) -> Dict[str, Any]:
|
201 |
"""创建基本的请求数据
|
202 |
-
|
203 |
Args:
|
204 |
content: 用户消息内容
|
205 |
stream: 是否使用流式响应
|
206 |
model: 模型名称
|
207 |
-
|
208 |
Returns:
|
209 |
包含完整请求数据的字典
|
210 |
"""
|
@@ -224,7 +211,7 @@ STATS = TestStats()
|
|
224 |
|
225 |
def run_test(func: Callable, name: str) -> None:
|
226 |
"""运行测试并记录结果
|
227 |
-
|
228 |
Args:
|
229 |
func: 测试函数
|
230 |
name: 测试名称
|
@@ -246,21 +233,21 @@ def test_non_stream_chat():
|
|
246 |
CONFIG["test_cases"]["basic"]["query"],
|
247 |
stream=False
|
248 |
)
|
249 |
-
|
250 |
# 发送请求
|
251 |
response = make_request(url, data)
|
252 |
-
|
253 |
# 打印响应
|
254 |
if OutputControl.is_verbose():
|
255 |
print("\n=== 非流式调用响应 ===")
|
256 |
response_json = response.json()
|
257 |
-
|
258 |
# 打印响应内容
|
259 |
print_json_response({
|
260 |
"model": response_json["model"],
|
261 |
"message": response_json["message"]
|
262 |
}, "响应内容")
|
263 |
-
|
264 |
# # 打印性能统计
|
265 |
# print_json_response({
|
266 |
# "total_duration": response_json["total_duration"],
|
@@ -273,7 +260,7 @@ def test_non_stream_chat():
|
|
273 |
|
274 |
def test_stream_chat():
|
275 |
"""测试流式调用 /api/chat 接口
|
276 |
-
|
277 |
使用 JSON Lines 格式处理流式响应,每行是一个完整的 JSON 对象。
|
278 |
响应格式:
|
279 |
{
|
@@ -286,7 +273,7 @@ def test_stream_chat():
|
|
286 |
},
|
287 |
"done": false
|
288 |
}
|
289 |
-
|
290 |
最后一条消息会包含性能统计信息,done 为 true。
|
291 |
"""
|
292 |
url = get_base_url()
|
@@ -294,10 +281,10 @@ def test_stream_chat():
|
|
294 |
CONFIG["test_cases"]["basic"]["query"],
|
295 |
stream=True
|
296 |
)
|
297 |
-
|
298 |
# 发送请求并获取流式响应
|
299 |
response = make_request(url, data, stream=True)
|
300 |
-
|
301 |
if OutputControl.is_verbose():
|
302 |
print("\n=== 流式调用响应 ===")
|
303 |
output_buffer = []
|
@@ -321,24 +308,24 @@ def test_stream_chat():
|
|
321 |
print("Error decoding JSON from response line")
|
322 |
finally:
|
323 |
response.close() # 确保关闭响应连接
|
324 |
-
|
325 |
# 打印一个换行
|
326 |
print()
|
327 |
|
328 |
def test_query_modes():
|
329 |
"""测试不同的查询模式前缀
|
330 |
-
|
331 |
支持的查询模式:
|
332 |
- /local: 本地检索模式,只在相关度高的文档中搜索
|
333 |
- /global: 全局检索模式,在所有文档中搜索
|
334 |
- /naive: 朴素模式,不使用任何优化策略
|
335 |
- /hybrid: 混合模式(默认),结合多种策略
|
336 |
-
|
337 |
每个模式都会返回相同格式的响应,但检索策略不同。
|
338 |
"""
|
339 |
url = get_base_url()
|
340 |
modes = ["local", "global", "naive", "hybrid", "mix"] # 支持的查询模式
|
341 |
-
|
342 |
for mode in modes:
|
343 |
if OutputControl.is_verbose():
|
344 |
print(f"\n=== 测试 /{mode} 模式 ===")
|
@@ -346,11 +333,11 @@ def test_query_modes():
|
|
346 |
f"/{mode} {CONFIG['test_cases']['basic']['query']}",
|
347 |
stream=False
|
348 |
)
|
349 |
-
|
350 |
# 发送请求
|
351 |
response = make_request(url, data)
|
352 |
response_json = response.json()
|
353 |
-
|
354 |
# 打印响应内容
|
355 |
print_json_response({
|
356 |
"model": response_json["model"],
|
@@ -359,13 +346,13 @@ def test_query_modes():
|
|
359 |
|
360 |
def create_error_test_data(error_type: str) -> Dict[str, Any]:
|
361 |
"""创建用于错误测试的请求数据
|
362 |
-
|
363 |
Args:
|
364 |
error_type: 错误类型,支持:
|
365 |
- empty_messages: 空消息列表
|
366 |
- invalid_role: 无效的角色字段
|
367 |
- missing_content: 缺少内容字段
|
368 |
-
|
369 |
Returns:
|
370 |
包含错误数据的请求字典
|
371 |
"""
|
@@ -399,19 +386,19 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]:
|
|
399 |
|
400 |
def test_stream_error_handling():
|
401 |
"""测试流式响应的错误处理
|
402 |
-
|
403 |
测试场景:
|
404 |
1. 空消息列表
|
405 |
2. 消息格式错误(缺少必需字段)
|
406 |
-
|
407 |
错误响应会立即返回,不会建立流式连接。
|
408 |
状态码应该是 4xx,并返回详细的错误信息。
|
409 |
"""
|
410 |
url = get_base_url()
|
411 |
-
|
412 |
if OutputControl.is_verbose():
|
413 |
print("\n=== 测试流式响应错误处理 ===")
|
414 |
-
|
415 |
# 测试空消息列表
|
416 |
if OutputControl.is_verbose():
|
417 |
print("\n--- 测试空消息列表(流式)---")
|
@@ -421,7 +408,7 @@ def test_stream_error_handling():
|
|
421 |
if response.status_code != 200:
|
422 |
print_json_response(response.json(), "错误信息")
|
423 |
response.close()
|
424 |
-
|
425 |
# 测试无效角色字段
|
426 |
if OutputControl.is_verbose():
|
427 |
print("\n--- 测试无效角色字段(流式)---")
|
@@ -444,23 +431,23 @@ def test_stream_error_handling():
|
|
444 |
|
445 |
def test_error_handling():
|
446 |
"""测试非流式响应的错误处理
|
447 |
-
|
448 |
测试场景:
|
449 |
1. 空消息列表
|
450 |
2. 消息格式错误(缺少必需字段)
|
451 |
-
|
452 |
错误响应格式:
|
453 |
{
|
454 |
"detail": "错误描述"
|
455 |
}
|
456 |
-
|
457 |
所有错误都应该返回合适的 HTTP 状态码和清晰的错误信息。
|
458 |
"""
|
459 |
url = get_base_url()
|
460 |
-
|
461 |
if OutputControl.is_verbose():
|
462 |
print("\n=== 测试错误处理 ===")
|
463 |
-
|
464 |
# 测试空消息列表
|
465 |
if OutputControl.is_verbose():
|
466 |
print("\n--- 测试空消息列表 ---")
|
@@ -469,7 +456,7 @@ def test_error_handling():
|
|
469 |
response = make_request(url, data)
|
470 |
print(f"状态码: {response.status_code}")
|
471 |
print_json_response(response.json(), "错误信息")
|
472 |
-
|
473 |
# 测试无效角色字段
|
474 |
if OutputControl.is_verbose():
|
475 |
print("\n--- 测试无效角色字段 ---")
|
@@ -490,7 +477,7 @@ def test_error_handling():
|
|
490 |
|
491 |
def get_test_cases() -> Dict[str, Callable]:
|
492 |
"""获取所有可用的测试用例
|
493 |
-
|
494 |
Returns:
|
495 |
测试名称到测试函数的映射字典
|
496 |
"""
|
@@ -564,21 +551,21 @@ def parse_args() -> argparse.Namespace:
|
|
564 |
|
565 |
if __name__ == "__main__":
|
566 |
args = parse_args()
|
567 |
-
|
568 |
# 设置输出模式
|
569 |
OutputControl.set_verbose(not args.quiet)
|
570 |
-
|
571 |
# 如果指定了查询内容,更新配置
|
572 |
if args.ask:
|
573 |
CONFIG["test_cases"]["basic"]["query"] = args.ask
|
574 |
-
|
575 |
# 如果指定了创建配置文件
|
576 |
if args.init_config:
|
577 |
create_default_config()
|
578 |
exit(0)
|
579 |
-
|
580 |
test_cases = get_test_cases()
|
581 |
-
|
582 |
try:
|
583 |
if "all" in args.tests:
|
584 |
# 运行所有测试
|
@@ -586,11 +573,11 @@ if __name__ == "__main__":
|
|
586 |
print("\n【基本功能测试】")
|
587 |
run_test(test_non_stream_chat, "非流式调用测试")
|
588 |
run_test(test_stream_chat, "流式调用测试")
|
589 |
-
|
590 |
if OutputControl.is_verbose():
|
591 |
print("\n【查询模式测试】")
|
592 |
run_test(test_query_modes, "查询模式测试")
|
593 |
-
|
594 |
if OutputControl.is_verbose():
|
595 |
print("\n【错误处理测试】")
|
596 |
run_test(test_error_handling, "错误处理测试")
|
|
|
1 |
"""
|
2 |
+
LightRAG Ollama Compatibility Interface Test Script
|
3 |
|
4 |
+
This script tests the LightRAG's Ollama compatibility interface, including:
|
5 |
+
1. Basic functionality tests (streaming and non-streaming responses)
|
6 |
+
2. Query mode tests (local, global, naive, hybrid, mix)
|
7 |
+
3. Error handling tests (including streaming and non-streaming scenarios)
|
8 |
|
9 |
+
All responses use the JSON Lines format, complying with the Ollama API specification.
|
10 |
"""
|
11 |
|
12 |
import requests
|
|
|
24 |
|
25 |
@classmethod
|
26 |
def set_verbose(cls, verbose: bool) -> None:
|
|
|
|
|
|
|
|
|
|
|
27 |
cls._verbose = verbose
|
28 |
|
29 |
@classmethod
|
30 |
def is_verbose(cls) -> bool:
|
|
|
|
|
|
|
|
|
|
|
31 |
return cls._verbose
|
32 |
|
33 |
@dataclass
|
|
|
38 |
duration: float
|
39 |
error: Optional[str] = None
|
40 |
timestamp: str = ""
|
41 |
+
|
42 |
def __post_init__(self):
|
|
|
43 |
if not self.timestamp:
|
44 |
self.timestamp = datetime.now().isoformat()
|
45 |
|
|
|
48 |
def __init__(self):
|
49 |
self.results: List[TestResult] = []
|
50 |
self.start_time = datetime.now()
|
51 |
+
|
52 |
def add_result(self, result: TestResult):
|
|
|
53 |
self.results.append(result)
|
54 |
+
|
55 |
def export_results(self, path: str = "test_results.json"):
|
56 |
"""导出测试结果到 JSON 文件
|
57 |
+
|
58 |
Args:
|
59 |
path: 输出文件路径
|
60 |
"""
|
|
|
69 |
"total_duration": sum(r.duration for r in self.results)
|
70 |
}
|
71 |
}
|
72 |
+
|
73 |
with open(path, "w", encoding="utf-8") as f:
|
74 |
json.dump(results_data, f, ensure_ascii=False, indent=2)
|
75 |
print(f"\n测试结果已保存到: {path}")
|
76 |
+
|
77 |
def print_summary(self):
|
|
|
78 |
total = len(self.results)
|
79 |
passed = sum(1 for r in self.results if r.success)
|
80 |
failed = total - passed
|
81 |
duration = sum(r.duration for r in self.results)
|
82 |
+
|
83 |
print("\n=== 测试结果摘要 ===")
|
84 |
print(f"开始时间: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
|
85 |
print(f"总用时: {duration:.2f}秒")
|
86 |
print(f"总计: {total} 个测试")
|
87 |
print(f"通过: {passed} 个")
|
88 |
print(f"失败: {failed} 个")
|
89 |
+
|
90 |
if failed > 0:
|
91 |
print("\n失败的测试:")
|
92 |
for result in self.results:
|
|
|
112 |
|
113 |
def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response:
|
114 |
"""发送 HTTP 请求,支持重试机制
|
115 |
+
|
116 |
Args:
|
117 |
url: 请求 URL
|
118 |
data: 请求数据
|
119 |
stream: 是否使用流式响应
|
120 |
+
|
121 |
Returns:
|
122 |
+
requests.Response: 对象
|
123 |
+
|
124 |
Raises:
|
125 |
requests.exceptions.RequestException: 请求失败且重试次数用完
|
126 |
"""
|
|
|
128 |
max_retries = server_config["max_retries"]
|
129 |
retry_delay = server_config["retry_delay"]
|
130 |
timeout = server_config["timeout"]
|
131 |
+
|
132 |
for attempt in range(max_retries):
|
133 |
try:
|
134 |
response = requests.post(
|
|
|
146 |
|
147 |
def load_config() -> Dict[str, Any]:
|
148 |
"""加载配置文件
|
149 |
+
|
150 |
首先尝试从当前目录的 config.json 加载,
|
151 |
如果不存在则使用默认配置
|
152 |
+
|
153 |
Returns:
|
154 |
配置字典
|
155 |
"""
|
|
|
161 |
|
162 |
def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None:
|
163 |
"""格式化打印 JSON 响应数据
|
164 |
+
|
165 |
Args:
|
166 |
data: 要打印的数据字典
|
167 |
title: 打印的标题
|
|
|
186 |
model: str = None
|
187 |
) -> Dict[str, Any]:
|
188 |
"""创建基本的请求数据
|
189 |
+
|
190 |
Args:
|
191 |
content: 用户消息内容
|
192 |
stream: 是否使用流式响应
|
193 |
model: 模型名称
|
194 |
+
|
195 |
Returns:
|
196 |
包含完整请求数据的字典
|
197 |
"""
|
|
|
211 |
|
212 |
def run_test(func: Callable, name: str) -> None:
|
213 |
"""运行测试并记录结果
|
214 |
+
|
215 |
Args:
|
216 |
func: 测试函数
|
217 |
name: 测试名称
|
|
|
233 |
CONFIG["test_cases"]["basic"]["query"],
|
234 |
stream=False
|
235 |
)
|
236 |
+
|
237 |
# 发送请求
|
238 |
response = make_request(url, data)
|
239 |
+
|
240 |
# 打印响应
|
241 |
if OutputControl.is_verbose():
|
242 |
print("\n=== 非流式调用响应 ===")
|
243 |
response_json = response.json()
|
244 |
+
|
245 |
# 打印响应内容
|
246 |
print_json_response({
|
247 |
"model": response_json["model"],
|
248 |
"message": response_json["message"]
|
249 |
}, "响应内容")
|
250 |
+
|
251 |
# # 打印性能统计
|
252 |
# print_json_response({
|
253 |
# "total_duration": response_json["total_duration"],
|
|
|
260 |
|
261 |
def test_stream_chat():
|
262 |
"""测试流式调用 /api/chat 接口
|
263 |
+
|
264 |
使用 JSON Lines 格式处理流式响应,每行是一个完整的 JSON 对象。
|
265 |
响应格式:
|
266 |
{
|
|
|
273 |
},
|
274 |
"done": false
|
275 |
}
|
276 |
+
|
277 |
最后一条消息会包含性能统计信息,done 为 true。
|
278 |
"""
|
279 |
url = get_base_url()
|
|
|
281 |
CONFIG["test_cases"]["basic"]["query"],
|
282 |
stream=True
|
283 |
)
|
284 |
+
|
285 |
# 发送请求并获取流式响应
|
286 |
response = make_request(url, data, stream=True)
|
287 |
+
|
288 |
if OutputControl.is_verbose():
|
289 |
print("\n=== 流式调用响应 ===")
|
290 |
output_buffer = []
|
|
|
308 |
print("Error decoding JSON from response line")
|
309 |
finally:
|
310 |
response.close() # 确保关闭响应连接
|
311 |
+
|
312 |
# 打印一个换行
|
313 |
print()
|
314 |
|
315 |
def test_query_modes():
|
316 |
"""测试不同的查询模式前缀
|
317 |
+
|
318 |
支持的查询模式:
|
319 |
- /local: 本地检索模式,只在相关度高的文档中搜索
|
320 |
- /global: 全局检索模式,在所有文档中搜索
|
321 |
- /naive: 朴素模式,不使用任何优化策略
|
322 |
- /hybrid: 混合模式(默认),结合多种策略
|
323 |
+
|
324 |
每个模式都会返回相同格式的响应,但检索策略不同。
|
325 |
"""
|
326 |
url = get_base_url()
|
327 |
modes = ["local", "global", "naive", "hybrid", "mix"] # 支持的查询模式
|
328 |
+
|
329 |
for mode in modes:
|
330 |
if OutputControl.is_verbose():
|
331 |
print(f"\n=== 测试 /{mode} 模式 ===")
|
|
|
333 |
f"/{mode} {CONFIG['test_cases']['basic']['query']}",
|
334 |
stream=False
|
335 |
)
|
336 |
+
|
337 |
# 发送请求
|
338 |
response = make_request(url, data)
|
339 |
response_json = response.json()
|
340 |
+
|
341 |
# 打印响应内容
|
342 |
print_json_response({
|
343 |
"model": response_json["model"],
|
|
|
346 |
|
347 |
def create_error_test_data(error_type: str) -> Dict[str, Any]:
|
348 |
"""创建用于错误测试的请求数据
|
349 |
+
|
350 |
Args:
|
351 |
error_type: 错误类型,支持:
|
352 |
- empty_messages: 空消息列表
|
353 |
- invalid_role: 无效的角色字段
|
354 |
- missing_content: 缺少内容字段
|
355 |
+
|
356 |
Returns:
|
357 |
包含错误数据的请求字典
|
358 |
"""
|
|
|
386 |
|
387 |
def test_stream_error_handling():
|
388 |
"""测试流式响应的错误处理
|
389 |
+
|
390 |
测试场景:
|
391 |
1. 空消息列表
|
392 |
2. 消息格式错误(缺少必需字段)
|
393 |
+
|
394 |
错误响应会立即返回,不会建立流式连接。
|
395 |
状态码应该是 4xx,并返回详细的错误信息。
|
396 |
"""
|
397 |
url = get_base_url()
|
398 |
+
|
399 |
if OutputControl.is_verbose():
|
400 |
print("\n=== 测试流式响应错误处理 ===")
|
401 |
+
|
402 |
# 测试空消息列表
|
403 |
if OutputControl.is_verbose():
|
404 |
print("\n--- 测试空消息列表(流式)---")
|
|
|
408 |
if response.status_code != 200:
|
409 |
print_json_response(response.json(), "错误信息")
|
410 |
response.close()
|
411 |
+
|
412 |
# 测试无效角色字段
|
413 |
if OutputControl.is_verbose():
|
414 |
print("\n--- 测试无效角色字段(流式)---")
|
|
|
431 |
|
432 |
def test_error_handling():
|
433 |
"""测试非流式响应的错误处理
|
434 |
+
|
435 |
测试场景:
|
436 |
1. 空消息列表
|
437 |
2. 消息格式错误(缺少必需字段)
|
438 |
+
|
439 |
错误响应格式:
|
440 |
{
|
441 |
"detail": "错误描述"
|
442 |
}
|
443 |
+
|
444 |
所有错误都应该返回合适的 HTTP 状态码和清晰的错误信息。
|
445 |
"""
|
446 |
url = get_base_url()
|
447 |
+
|
448 |
if OutputControl.is_verbose():
|
449 |
print("\n=== 测试错误处理 ===")
|
450 |
+
|
451 |
# 测试空消息列表
|
452 |
if OutputControl.is_verbose():
|
453 |
print("\n--- 测试空消息列表 ---")
|
|
|
456 |
response = make_request(url, data)
|
457 |
print(f"状态码: {response.status_code}")
|
458 |
print_json_response(response.json(), "错误信息")
|
459 |
+
|
460 |
# 测试无效角色字段
|
461 |
if OutputControl.is_verbose():
|
462 |
print("\n--- 测试无效角色字段 ---")
|
|
|
477 |
|
478 |
def get_test_cases() -> Dict[str, Callable]:
|
479 |
"""获取所有可用的测试用例
|
480 |
+
|
481 |
Returns:
|
482 |
测试名称到测试函数的映射字典
|
483 |
"""
|
|
|
551 |
|
552 |
if __name__ == "__main__":
|
553 |
args = parse_args()
|
554 |
+
|
555 |
# 设置输出模式
|
556 |
OutputControl.set_verbose(not args.quiet)
|
557 |
+
|
558 |
# 如果指定了查询内容,更新配置
|
559 |
if args.ask:
|
560 |
CONFIG["test_cases"]["basic"]["query"] = args.ask
|
561 |
+
|
562 |
# 如果指定了创建配置文件
|
563 |
if args.init_config:
|
564 |
create_default_config()
|
565 |
exit(0)
|
566 |
+
|
567 |
test_cases = get_test_cases()
|
568 |
+
|
569 |
try:
|
570 |
if "all" in args.tests:
|
571 |
# 运行所有测试
|
|
|
573 |
print("\n【基本功能测试】")
|
574 |
run_test(test_non_stream_chat, "非流式调用测试")
|
575 |
run_test(test_stream_chat, "流式调用测试")
|
576 |
+
|
577 |
if OutputControl.is_verbose():
|
578 |
print("\n【查询模式测试】")
|
579 |
run_test(test_query_modes, "查询模式测试")
|
580 |
+
|
581 |
if OutputControl.is_verbose():
|
582 |
print("\n【错误处理测试】")
|
583 |
run_test(test_error_handling, "错误处理测试")
|