File size: 18,594 Bytes
9f80fd7
99a3d9e
9f80fd7
99a3d9e
 
 
 
9f80fd7
99a3d9e
9f80fd7
 
fda27b8
 
9f80fd7
 
2c2fa06
9f80fd7
 
 
fda27b8
9f80fd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99a3d9e
9f80fd7
 
 
 
 
 
 
 
 
99a3d9e
9f80fd7
 
99a3d9e
9f80fd7
 
99a3d9e
9f80fd7
 
 
 
 
 
 
 
 
 
 
 
 
 
99a3d9e
9f80fd7
 
 
99a3d9e
9f80fd7
 
 
 
 
99a3d9e
9f80fd7
 
 
 
 
 
99a3d9e
9f80fd7
 
 
 
 
 
 
 
 
 
 
fda27b8
9f80fd7
 
 
 
 
 
835a76d
9f80fd7
 
 
 
 
 
99a3d9e
9f80fd7
 
 
 
99a3d9e
9f80fd7
99a3d9e
 
9f80fd7
 
 
 
 
 
 
99a3d9e
9f80fd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99a3d9e
9f80fd7
 
99a3d9e
9f80fd7
 
 
 
 
 
 
 
 
 
 
99a3d9e
9f80fd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99a3d9e
9f80fd7
 
 
 
99a3d9e
9f80fd7
 
 
 
 
fda27b8
 
 
9f80fd7
fda27b8
 
9f80fd7
fda27b8
9f80fd7
 
 
 
 
 
99a3d9e
9f80fd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99a3d9e
fda27b8
9f80fd7
99a3d9e
fda27b8
9f80fd7
 
f55e12d
99a3d9e
9f80fd7
 
f55e12d
 
9f80fd7
99a3d9e
2b61630
 
 
 
 
 
 
 
 
fda27b8
 
9f80fd7
99a3d9e
9f80fd7
 
 
fda27b8
9f80fd7
 
 
 
 
 
 
fda27b8
99a3d9e
9f80fd7
 
 
 
835a76d
9f80fd7
 
99a3d9e
9f80fd7
 
99a3d9e
9f80fd7
 
fda27b8
 
9f80fd7
 
 
 
 
 
 
2b61630
9f80fd7
 
 
 
 
 
 
 
 
fda27b8
 
99a3d9e
9f80fd7
 
 
 
 
99a3d9e
9f80fd7
 
 
 
 
99a3d9e
9f80fd7
 
 
528d6fd
99a3d9e
9f80fd7
 
 
 
835a76d
9f80fd7
 
99a3d9e
9f80fd7
 
 
99a3d9e
9f80fd7
 
 
 
 
 
 
 
99a3d9e
9f80fd7
 
 
 
 
99a3d9e
9f80fd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99a3d9e
9f80fd7
 
 
99a3d9e
9f80fd7
 
 
 
99a3d9e
9f80fd7
 
99a3d9e
9f80fd7
 
 
 
 
 
 
 
 
99a3d9e
9f80fd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99a3d9e
9f80fd7
 
 
99a3d9e
9f80fd7
 
 
 
99a3d9e
9f80fd7
 
 
99a3d9e
9f80fd7
 
99a3d9e
9f80fd7
 
 
 
 
 
 
 
99a3d9e
9f80fd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99a3d9e
9f80fd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
835a76d
 
 
 
 
 
 
 
9f80fd7
 
 
 
 
 
 
 
 
9b0d64b
 
9f80fd7
 
835a76d
 
 
 
 
9f80fd7
 
fda27b8
 
9f80fd7
99a3d9e
9f80fd7
 
99a3d9e
835a76d
 
 
99a3d9e
9f80fd7
 
 
 
99a3d9e
9f80fd7
99a3d9e
9f80fd7
 
 
 
 
 
 
99a3d9e
9f80fd7
 
 
99a3d9e
9f80fd7
 
 
 
 
 
 
 
 
 
 
 
 
9b0d64b
9f80fd7
9b0d64b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
"""
LightRAG Ollama Compatibility Interface Test Script

This script tests the LightRAG's Ollama compatibility interface, including:
1. Basic functionality tests (streaming and non-streaming responses)
2. Query mode tests (local, global, naive, hybrid)
3. Error handling tests (including streaming and non-streaming scenarios)

All responses use the JSON Lines format, complying with the Ollama API specification.
"""

import requests
import json
import argparse
import time
from typing import Dict, Any, Optional, List, Callable
from dataclasses import dataclass, asdict
from datetime import datetime
from pathlib import Path

class OutputControl:
    """输出控制类,管理测试输出的详细程度"""
    _verbose: bool = False

    @classmethod
    def set_verbose(cls, verbose: bool) -> None:
        cls._verbose = verbose

    @classmethod
    def is_verbose(cls) -> bool:
        return cls._verbose

@dataclass
class TestResult:
    """测试结果数据类"""
    name: str
    success: bool
    duration: float
    error: Optional[str] = None
    timestamp: str = ""

    def __post_init__(self):
        if not self.timestamp:
            self.timestamp = datetime.now().isoformat()

class TestStats:
    """测试统计信息"""
    def __init__(self):
        self.results: List[TestResult] = []
        self.start_time = datetime.now()

    def add_result(self, result: TestResult):
        self.results.append(result)

    def export_results(self, path: str = "test_results.json"):
        """导出测试结果到 JSON 文件

        Args:
            path: 输出文件路径
        """
        results_data = {
            "start_time": self.start_time.isoformat(),
            "end_time": datetime.now().isoformat(),
            "results": [asdict(r) for r in self.results],
            "summary": {
                "total": len(self.results),
                "passed": sum(1 for r in self.results if r.success),
                "failed": sum(1 for r in self.results if not r.success),
                "total_duration": sum(r.duration for r in self.results)
            }
        }

        with open(path, "w", encoding="utf-8") as f:
            json.dump(results_data, f, ensure_ascii=False, indent=2)
        print(f"\n测试结果已保存到: {path}")

    def print_summary(self):
        total = len(self.results)
        passed = sum(1 for r in self.results if r.success)
        failed = total - passed
        duration = sum(r.duration for r in self.results)

        print("\n=== 测试结果摘要 ===")
        print(f"开始时间: {self.start_time.strftime('%Y-%m-%d %H:%M:%S')}")
        print(f"总用时: {duration:.2f}秒")
        print(f"总计: {total} 个测试")
        print(f"通过: {passed} 个")
        print(f"失败: {failed} 个")

        if failed > 0:
            print("\n失败的测试:")
            for result in self.results:
                if not result.success:
                    print(f"- {result.name}: {result.error}")

# 默认配置
DEFAULT_CONFIG = {
    "server": {
        "host": "localhost",
        "port": 9621,
        "model": "lightrag:latest",
        "timeout": 30,  # 请求超时时间(秒)
        "max_retries": 3,  # 最大重试次数
        "retry_delay": 1  # 重试间隔(秒)
    },
    "test_cases": {
        "basic": {
            "query": "唐僧有几个徒弟"
        }
    }
}

def make_request(url: str, data: Dict[str, Any], stream: bool = False) -> requests.Response:
    """发送 HTTP 请求,支持重试机制

    Args:
        url: 请求 URL
        data: 请求数据
        stream: 是否使用流式响应

    Returns:
        requests.Response: 对象

    Raises:
        requests.exceptions.RequestException: 请求失败且重试次数用完
    """
    server_config = CONFIG["server"]
    max_retries = server_config["max_retries"]
    retry_delay = server_config["retry_delay"]
    timeout = server_config["timeout"]

    for attempt in range(max_retries):
        try:
            response = requests.post(
                url,
                json=data,
                stream=stream,
                timeout=timeout
            )
            return response
        except requests.exceptions.RequestException as e:
            if attempt == max_retries - 1:  # 最后一次重试
                raise
            print(f"\n请求失败,{retry_delay}秒后重试: {str(e)}")
            time.sleep(retry_delay)

def load_config() -> Dict[str, Any]:
    """加载配置文件

    首先尝试从当前目录的 config.json 加载,
    如果不存在则使用默认配置

    Returns:
        配置字典
    """
    config_path = Path("config.json")
    if config_path.exists():
        with open(config_path, "r", encoding="utf-8") as f:
            return json.load(f)
    return DEFAULT_CONFIG

def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2) -> None:
    """格式化打印 JSON 响应数据

    Args:
        data: 要打印的数据字典
        title: 打印的标题
        indent: JSON 缩进空格数
    """
    if OutputControl.is_verbose():
        if title:
            print(f"\n=== {title} ===")
        print(json.dumps(data, ensure_ascii=False, indent=indent))

# 全局配置
CONFIG = load_config()

def get_base_url() -> str:
    """返回基础 URL"""
    server = CONFIG["server"]
    return f"http://{server['host']}:{server['port']}/api/chat"

def create_request_data(
    content: str,
    stream: bool = False,
    model: str = None
) -> Dict[str, Any]:
    """创建基本的请求数据

    Args:
        content: 用户消息内容
        stream: 是否使用流式响应
        model: 模型名称

    Returns:
        包含完整请求数据的字典
    """
    return {
        "model": model or CONFIG["server"]["model"],
        "messages": [
            {
                "role": "user",
                "content": content
            }
        ],
        "stream": stream
    }

# 全局测试统计
STATS = TestStats()

def run_test(func: Callable, name: str) -> None:
    """运行测试并记录结果

    Args:
        func: 测试函数
        name: 测试名称
    """
    start_time = time.time()
    try:
        func()
        duration = time.time() - start_time
        STATS.add_result(TestResult(name, True, duration))
    except Exception as e:
        duration = time.time() - start_time
        STATS.add_result(TestResult(name, False, duration, str(e)))
        raise

def test_non_stream_chat():
    """测试非流式调用 /api/chat 接口"""
    url = get_base_url()
    data = create_request_data(
        CONFIG["test_cases"]["basic"]["query"],
        stream=False
    )

    # 发送请求
    response = make_request(url, data)

    # 打印响应
    if OutputControl.is_verbose():
        print("\n=== 非流式调用响应 ===")
    response_json = response.json()

    # 打印响应内容
    print_json_response({
        "model": response_json["model"],
        "message": response_json["message"]
    }, "响应内容")

    # # 打印性能统计
    # print_json_response({
    #     "total_duration": response_json["total_duration"],
    #     "load_duration": response_json["load_duration"],
    #     "prompt_eval_count": response_json["prompt_eval_count"],
    #     "prompt_eval_duration": response_json["prompt_eval_duration"],
    #     "eval_count": response_json["eval_count"],
    #     "eval_duration": response_json["eval_duration"]
    # }, "性能统计")

def test_stream_chat():
    """测试流式调用 /api/chat 接口

    使用 JSON Lines 格式处理流式响应,每行是一个完整的 JSON 对象。
    响应格式:
    {
        "model": "lightrag:latest",
        "created_at": "2024-01-15T00:00:00Z",
        "message": {
            "role": "assistant",
            "content": "部分响应内容",
            "images": null
        },
        "done": false
    }

    最后一条消息会包含性能统计信息,done 为 true。
    """
    url = get_base_url()
    data = create_request_data(
        CONFIG["test_cases"]["basic"]["query"],
        stream=True
    )

    # 发送请求并获取流式响应
    response = make_request(url, data, stream=True)

    if OutputControl.is_verbose():
        print("\n=== 流式调用响应 ===")
    output_buffer = []
    try:
        for line in response.iter_lines():
            if line:  # 跳过空行
                try:
                    # 解码并解析 JSON
                    data = json.loads(line.decode('utf-8'))
                    if data.get("done", True):  # 如果是完成标记
                        if "total_duration" in data:  # 最终的性能统计消息
                            # print_json_response(data, "性能统计")
                            break
                    else:  # 正常的内容消息
                        message = data.get("message", {})
                        content = message.get("content", "")
                        if content:  # 只收集非空内容
                            output_buffer.append(content)
                            print(content, end="", flush=True)  # 实时打印内容
                except json.JSONDecodeError:
                    print("Error decoding JSON from response line")
    finally:
        response.close()  # 确保关闭响应连接

    # 打印一个换行
    print()

def test_query_modes():
    """测试不同的查询模式前缀

    支持的查询模式:
    - /local: 本地检索模式,只在相关度高的文档中搜索
    - /global: 全局检索模式,在所有文档中搜索
    - /naive: 朴素模式,不使用任何优化策略
    - /hybrid: 混合模式(默认),结合多种策略

    每个模式都会返回相同格式的响应,但检索策略不同。
    """
    url = get_base_url()
    modes = ["local", "global", "naive", "hybrid", "mix"]  # 支持的查询模式

    for mode in modes:
        if OutputControl.is_verbose():
            print(f"\n=== 测试 /{mode} 模式 ===")
        data = create_request_data(
            f"/{mode} {CONFIG['test_cases']['basic']['query']}",
            stream=False
        )

        # 发送请求
        response = make_request(url, data)
        response_json = response.json()

        # 打印响应内容
        print_json_response({
            "model": response_json["model"],
            "message": response_json["message"]
        })

def create_error_test_data(error_type: str) -> Dict[str, Any]:
    """创建用于错误测试的请求数据

    Args:
        error_type: 错误类型,支持:
            - empty_messages: 空消息列表
            - invalid_role: 无效的角色字段
            - missing_content: 缺少内容字段

    Returns:
        包含错误数据的请求字典
    """
    error_data = {
        "empty_messages": {
            "model": "lightrag:latest",
            "messages": [],
            "stream": True
        },
        "invalid_role": {
            "model": "lightrag:latest",
            "messages": [
                {
                    "invalid_role": "user",
                    "content": "测试消息"
                }
            ],
            "stream": True
        },
        "missing_content": {
            "model": "lightrag:latest",
            "messages": [
                {
                    "role": "user"
                }
            ],
            "stream": True
        }
    }
    return error_data.get(error_type, error_data["empty_messages"])

def test_stream_error_handling():
    """测试流式响应的错误处理

    测试场景:
    1. 空消息列表
    2. 消息格式错误(缺少必需字段)

    错误响应会立即返回,不会建立流式连接。
    状态码应该是 4xx,并返回详细的错误信息。
    """
    url = get_base_url()

    if OutputControl.is_verbose():
        print("\n=== 测试流式响应错误处理 ===")

    # 测试空消息列表
    if OutputControl.is_verbose():
        print("\n--- 测试空消息列表(流式)---")
    data = create_error_test_data("empty_messages")
    response = make_request(url, data, stream=True)
    print(f"状态码: {response.status_code}")
    if response.status_code != 200:
        print_json_response(response.json(), "错误信息")
    response.close()

    # 测试无效角色字段
    if OutputControl.is_verbose():
        print("\n--- 测试无效角色字段(流式)---")
    data = create_error_test_data("invalid_role")
    response = make_request(url, data, stream=True)
    print(f"状态码: {response.status_code}")
    if response.status_code != 200:
        print_json_response(response.json(), "错误信息")
    response.close()

    # 测试缺少内容字段
    if OutputControl.is_verbose():
        print("\n--- 测试缺少内容字段(流式)---")
    data = create_error_test_data("missing_content")
    response = make_request(url, data, stream=True)
    print(f"状态码: {response.status_code}")
    if response.status_code != 200:
        print_json_response(response.json(), "错误信息")
    response.close()

def test_error_handling():
    """测试非流式响应的错误处理

    测试场景:
    1. 空消息列表
    2. 消息格式错误(缺少必需字段)

    错误响应格式:
    {
        "detail": "错误描述"
    }

    所有错误都应该返回合适的 HTTP 状态码和清晰的错误信息。
    """
    url = get_base_url()

    if OutputControl.is_verbose():
        print("\n=== 测试错误处理 ===")

    # 测试空消息列表
    if OutputControl.is_verbose():
        print("\n--- 测试空消息列表 ---")
    data = create_error_test_data("empty_messages")
    data["stream"] = False  # 修改为非流式模式
    response = make_request(url, data)
    print(f"状态码: {response.status_code}")
    print_json_response(response.json(), "错误信息")

    # 测试无效角色字段
    if OutputControl.is_verbose():
        print("\n--- 测试无效角色字段 ---")
    data = create_error_test_data("invalid_role")
    data["stream"] = False  # 修改为非流式模式
    response = make_request(url, data)
    print(f"状态码: {response.status_code}")
    print_json_response(response.json(), "错误信息")

    # 测试缺少内容字段
    if OutputControl.is_verbose():
        print("\n--- 测试缺少内容字段 ---")
    data = create_error_test_data("missing_content")
    data["stream"] = False  # 修改为非流式模式
    response = make_request(url, data)
    print(f"状态码: {response.status_code}")
    print_json_response(response.json(), "错误信息")

def get_test_cases() -> Dict[str, Callable]:
    """获取所有可用的测试用例

    Returns:
        测试名称到测试函数的映射字典
    """
    return {
        "non_stream": test_non_stream_chat,
        "stream": test_stream_chat,
        "modes": test_query_modes,
        "errors": test_error_handling,
        "stream_errors": test_stream_error_handling
    }

def create_default_config():
    """创建默认配置文件"""
    config_path = Path("config.json")
    if not config_path.exists():
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(DEFAULT_CONFIG, f, ensure_ascii=False, indent=2)
        print(f"已创建默认配置文件: {config_path}")

def parse_args() -> argparse.Namespace:
    """解析命令行参数"""
    parser = argparse.ArgumentParser(
        description="LightRAG Ollama 兼容接口测试",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
配置文件 (config.json):
  {
    "server": {
      "host": "localhost",      # 服务器地址
      "port": 9621,            # 服务器端口
      "model": "lightrag:latest" # 默认模型名称
    },
    "test_cases": {
      "basic": {
        "query": "测试查询",      # 基本查询文本
        "stream_query": "流式查询" # 流式查询文本
      }
    }
  }
"""
    )
    parser.add_argument(
        "-q", "--quiet",
        action="store_true",
        help="静默模式,只显示测试结果摘要"
    )
    parser.add_argument(
        "-a", "--ask",
        type=str,
        help="指定查询内容,会覆盖配置文件中的查询设置"
    )
    parser.add_argument(
        "--init-config",
        action="store_true",
        help="创建默认配置文件"
    )
    parser.add_argument(
        "--output",
        type=str,
        default="",
        help="测试结果输出文件路径,默认不输出到文件"
    )
    parser.add_argument(
        "--tests",
        nargs="+",
        choices=list(get_test_cases().keys()) + ["all"],
        default=["all"],
        help="要运行的测试用例,可选: %(choices)s。使用 all 运行所有测试"
    )
    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()

    # 设置输出模式
    OutputControl.set_verbose(not args.quiet)

    # 如果指定了查询内容,更新配置
    if args.ask:
        CONFIG["test_cases"]["basic"]["query"] = args.ask

    # 如果指定了创建配置文件
    if args.init_config:
        create_default_config()
        exit(0)

    test_cases = get_test_cases()

    try:
        if "all" in args.tests:
            # 运行所有测试
            if OutputControl.is_verbose():
                print("\n【基本功能测试】")
            run_test(test_non_stream_chat, "非流式调用测试")
            run_test(test_stream_chat, "流式调用测试")

            if OutputControl.is_verbose():
                print("\n【查询模式测试】")
            run_test(test_query_modes, "查询模式测试")

            if OutputControl.is_verbose():
                print("\n【错误处理测试】")
            run_test(test_error_handling, "错误处理测试")
            run_test(test_stream_error_handling, "流式错误处理测试")
        else:
            # 运行指定的测试
            for test_name in args.tests:
                if OutputControl.is_verbose():
                    print(f"\n【运行测试: {test_name}】")
                run_test(test_cases[test_name], test_name)
    except Exception as e:
        print(f"\n发生错误: {str(e)}")
    finally:
        # 打印测试统计
        STATS.print_summary()
        # 如果指定了输出文件路径,则导出结果
        if args.output:
            STATS.export_results(args.output)