yangdx committed on
Commit
e83d6af
·
1 Parent(s): e5ea22e

Add generate API tests and enhance chat API tests

Browse files

- Add non-streaming generate API test
- Add streaming generate API test
- Add generate error handling tests
- Add generate performance stats test
- Add generate concurrent request test

Files changed (1) hide show
  1. test_lightrag_ollama_chat.py +294 -25
test_lightrag_ollama_chat.py CHANGED
@@ -108,7 +108,10 @@ DEFAULT_CONFIG = {
108
  "max_retries": 3,
109
  "retry_delay": 1,
110
  },
111
- "test_cases": {"basic": {"query": "唐僧有几个徒弟"}},
 
 
 
112
  }
113
 
114
 
@@ -174,22 +177,27 @@ def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2)
174
  CONFIG = load_config()
175
 
176
 
177
- def get_base_url() -> str:
178
- """Return the base URL"""
 
 
 
 
 
179
  server = CONFIG["server"]
180
- return f"http://{server['host']}:{server['port']}/api/chat"
181
 
182
 
183
- def create_request_data(
184
  content: str, stream: bool = False, model: str = None
185
  ) -> Dict[str, Any]:
186
- """Create basic request data
187
  Args:
188
  content: User message content
189
  stream: Whether to use streaming response
190
  model: Model name
191
  Returns:
192
- Dictionary containing complete request data
193
  """
194
  return {
195
  "model": model or CONFIG["server"]["model"],
@@ -197,6 +205,34 @@ def create_request_data(
197
  "stream": stream,
198
  }
199
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
  # Global test statistics
202
  STATS = TestStats()
@@ -219,10 +255,10 @@ def run_test(func: Callable, name: str) -> None:
219
  raise
220
 
221
 
222
- def test_non_stream_chat():
223
  """Test non-streaming call to /api/chat endpoint"""
224
  url = get_base_url()
225
- data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=False)
226
 
227
  # Send request
228
  response = make_request(url, data)
@@ -239,7 +275,7 @@ def test_non_stream_chat():
239
  )
240
 
241
 
242
- def test_stream_chat():
243
  """Test streaming call to /api/chat endpoint
244
 
245
  Use JSON Lines format to process streaming responses, each line is a complete JSON object.
@@ -258,7 +294,7 @@ def test_stream_chat():
258
  The last message will contain performance statistics, with done set to true.
259
  """
260
  url = get_base_url()
261
- data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=True)
262
 
263
  # Send request and get streaming response
264
  response = make_request(url, data, stream=True)
@@ -295,7 +331,7 @@ def test_stream_chat():
295
  print()
296
 
297
 
298
- def test_query_modes():
299
  """Test different query mode prefixes
300
 
301
  Supported query modes:
@@ -313,7 +349,7 @@ def test_query_modes():
313
  for mode in modes:
314
  if OutputControl.is_verbose():
315
  print(f"\n=== Testing /{mode} mode ===")
316
- data = create_request_data(
317
  f"/{mode} {CONFIG['test_cases']['basic']['query']}", stream=False
318
  )
319
 
@@ -354,7 +390,7 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]:
354
  return error_data.get(error_type, error_data["empty_messages"])
355
 
356
 
357
- def test_stream_error_handling():
358
  """Test error handling for streaming responses
359
 
360
  Test scenarios:
@@ -400,7 +436,7 @@ def test_stream_error_handling():
400
  response.close()
401
 
402
 
403
- def test_error_handling():
404
  """Test error handling for non-streaming responses
405
 
406
  Test scenarios:
@@ -447,6 +483,228 @@ def test_error_handling():
447
  print_json_response(response.json(), "Error message")
448
 
449
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
450
  def get_test_cases() -> Dict[str, Callable]:
451
  """Get all available test cases
452
  Returns:
@@ -458,6 +716,13 @@ def get_test_cases() -> Dict[str, Callable]:
458
  "modes": test_query_modes,
459
  "errors": test_error_handling,
460
  "stream_errors": test_stream_error_handling,
 
 
 
 
 
 
 
461
  }
462
 
463
 
@@ -544,18 +809,22 @@ if __name__ == "__main__":
544
  if "all" in args.tests:
545
  # Run all tests
546
  if OutputControl.is_verbose():
547
- print("\n【Basic Functionality Tests】")
548
- run_test(test_non_stream_chat, "Non-streaming Call Test")
549
- run_test(test_stream_chat, "Streaming Call Test")
550
-
551
- if OutputControl.is_verbose():
552
- print("\n【Query Mode Tests】")
553
- run_test(test_query_modes, "Query Mode Test")
554
 
555
  if OutputControl.is_verbose():
556
- print("\n【Error Handling Tests】")
557
- run_test(test_error_handling, "Error Handling Test")
558
- run_test(test_stream_error_handling, "Streaming Error Handling Test")
 
 
 
 
 
559
  else:
560
  # Run specified tests
561
  for test_name in args.tests:
 
108
  "max_retries": 3,
109
  "retry_delay": 1,
110
  },
111
+ "test_cases": {
112
+ "basic": {"query": "唐僧有几个徒弟"},
113
+ "generate": {"query": "电视剧西游记导演是谁"}
114
+ },
115
  }
116
 
117
 
 
177
  CONFIG = load_config()
178
 
179
 
180
+ def get_base_url(endpoint: str = "chat") -> str:
181
+ """Return the base URL for specified endpoint
182
+ Args:
183
+ endpoint: API endpoint name (chat or generate)
184
+ Returns:
185
+ Complete URL for the endpoint
186
+ """
187
  server = CONFIG["server"]
188
+ return f"http://{server['host']}:{server['port']}/api/{endpoint}"
189
 
190
 
191
+ def create_chat_request_data(
192
  content: str, stream: bool = False, model: str = None
193
  ) -> Dict[str, Any]:
194
+ """Create chat request data
195
  Args:
196
  content: User message content
197
  stream: Whether to use streaming response
198
  model: Model name
199
  Returns:
200
+ Dictionary containing complete chat request data
201
  """
202
  return {
203
  "model": model or CONFIG["server"]["model"],
 
205
  "stream": stream,
206
  }
207
 
208
+ def create_generate_request_data(
209
+ prompt: str,
210
+ system: str = None,
211
+ stream: bool = False,
212
+ model: str = None,
213
+ options: Dict[str, Any] = None
214
+ ) -> Dict[str, Any]:
215
+ """Create generate request data
216
+ Args:
217
+ prompt: Generation prompt
218
+ system: System prompt
219
+ stream: Whether to use streaming response
220
+ model: Model name
221
+ options: Additional options
222
+ Returns:
223
+ Dictionary containing complete generate request data
224
+ """
225
+ data = {
226
+ "model": model or CONFIG["server"]["model"],
227
+ "prompt": prompt,
228
+ "stream": stream
229
+ }
230
+ if system:
231
+ data["system"] = system
232
+ if options:
233
+ data["options"] = options
234
+ return data
235
+
236
 
237
  # Global test statistics
238
  STATS = TestStats()
 
255
  raise
256
 
257
 
258
+ def test_non_stream_chat() -> None:
259
  """Test non-streaming call to /api/chat endpoint"""
260
  url = get_base_url()
261
+ data = create_chat_request_data(CONFIG["test_cases"]["basic"]["query"], stream=False)
262
 
263
  # Send request
264
  response = make_request(url, data)
 
275
  )
276
 
277
 
278
+ def test_stream_chat() -> None:
279
  """Test streaming call to /api/chat endpoint
280
 
281
  Use JSON Lines format to process streaming responses, each line is a complete JSON object.
 
294
  The last message will contain performance statistics, with done set to true.
295
  """
296
  url = get_base_url()
297
+ data = create_chat_request_data(CONFIG["test_cases"]["basic"]["query"], stream=True)
298
 
299
  # Send request and get streaming response
300
  response = make_request(url, data, stream=True)
 
331
  print()
332
 
333
 
334
+ def test_query_modes() -> None:
335
  """Test different query mode prefixes
336
 
337
  Supported query modes:
 
349
  for mode in modes:
350
  if OutputControl.is_verbose():
351
  print(f"\n=== Testing /{mode} mode ===")
352
+ data = create_chat_request_data(
353
  f"/{mode} {CONFIG['test_cases']['basic']['query']}", stream=False
354
  )
355
 
 
390
  return error_data.get(error_type, error_data["empty_messages"])
391
 
392
 
393
+ def test_stream_error_handling() -> None:
394
  """Test error handling for streaming responses
395
 
396
  Test scenarios:
 
436
  response.close()
437
 
438
 
439
+ def test_error_handling() -> None:
440
  """Test error handling for non-streaming responses
441
 
442
  Test scenarios:
 
483
  print_json_response(response.json(), "Error message")
484
 
485
 
486
+ def test_non_stream_generate() -> None:
487
+ """Test non-streaming call to /api/generate endpoint"""
488
+ url = get_base_url("generate")
489
+ data = create_generate_request_data(
490
+ CONFIG["test_cases"]["generate"]["query"],
491
+ stream=False
492
+ )
493
+
494
+ # Send request
495
+ response = make_request(url, data)
496
+
497
+ # Print response
498
+ if OutputControl.is_verbose():
499
+ print("\n=== Non-streaming generate response ===")
500
+ response_json = response.json()
501
+
502
+ # Print response content
503
+ print_json_response(
504
+ {
505
+ "model": response_json["model"],
506
+ "response": response_json["response"],
507
+ "done": response_json["done"]
508
+ },
509
+ "Response content"
510
+ )
511
+
512
+ def test_stream_generate() -> None:
513
+ """Test streaming call to /api/generate endpoint"""
514
+ url = get_base_url("generate")
515
+ data = create_generate_request_data(
516
+ CONFIG["test_cases"]["generate"]["query"],
517
+ stream=True
518
+ )
519
+
520
+ # Send request and get streaming response
521
+ response = make_request(url, data, stream=True)
522
+
523
+ if OutputControl.is_verbose():
524
+ print("\n=== Streaming generate response ===")
525
+ output_buffer = []
526
+ try:
527
+ for line in response.iter_lines():
528
+ if line: # Skip empty lines
529
+ try:
530
+ # Decode and parse JSON
531
+ data = json.loads(line.decode("utf-8"))
532
+ if data.get("done", True): # If it's the completion marker
533
+ if "total_duration" in data: # Final performance statistics message
534
+ break
535
+ else: # Normal content message
536
+ content = data.get("response", "")
537
+ if content: # Only collect non-empty content
538
+ output_buffer.append(content)
539
+ print(content, end="", flush=True) # Print content in real-time
540
+ except json.JSONDecodeError:
541
+ print("Error decoding JSON from response line")
542
+ finally:
543
+ response.close() # Ensure the response connection is closed
544
+
545
+ # Print a newline
546
+ print()
547
+
548
+ def test_generate_with_system() -> None:
549
+ """Test generate with system prompt"""
550
+ url = get_base_url("generate")
551
+ data = create_generate_request_data(
552
+ CONFIG["test_cases"]["generate"]["query"],
553
+ system="你是一个知识渊博的助手",
554
+ stream=False
555
+ )
556
+
557
+ # Send request
558
+ response = make_request(url, data)
559
+
560
+ # Print response
561
+ if OutputControl.is_verbose():
562
+ print("\n=== Generate with system prompt response ===")
563
+ response_json = response.json()
564
+
565
+ # Print response content
566
+ print_json_response(
567
+ {
568
+ "model": response_json["model"],
569
+ "response": response_json["response"],
570
+ "done": response_json["done"]
571
+ },
572
+ "Response content"
573
+ )
574
+
575
+ def test_generate_error_handling() -> None:
576
+ """Test error handling for generate endpoint"""
577
+ url = get_base_url("generate")
578
+
579
+ # Test empty prompt
580
+ if OutputControl.is_verbose():
581
+ print("\n=== Testing empty prompt ===")
582
+ data = create_generate_request_data("", stream=False)
583
+ response = make_request(url, data)
584
+ print(f"Status code: {response.status_code}")
585
+ print_json_response(response.json(), "Error message")
586
+
587
+ # Test invalid options
588
+ if OutputControl.is_verbose():
589
+ print("\n=== Testing invalid options ===")
590
+ data = create_generate_request_data(
591
+ CONFIG["test_cases"]["basic"]["query"],
592
+ options={"invalid_option": "value"},
593
+ stream=False
594
+ )
595
+ response = make_request(url, data)
596
+ print(f"Status code: {response.status_code}")
597
+ print_json_response(response.json(), "Error message")
598
+
599
+ # Test very long input
600
+ if OutputControl.is_verbose():
601
+ print("\n=== Testing very long input ===")
602
+ long_text = "测试" * 10000 # Create a very long input
603
+ data = create_generate_request_data(long_text, stream=False)
604
+ response = make_request(url, data)
605
+ print(f"Status code: {response.status_code}")
606
+ print_json_response(response.json(), "Error message")
607
+
608
+ def test_generate_performance_stats() -> None:
609
+ """Test performance statistics in generate response"""
610
+ url = get_base_url("generate")
611
+
612
+ # Test with different length inputs to verify token counting
613
+ inputs = [
614
+ "你好", # Short Chinese
615
+ "Hello world", # Short English
616
+ "这是一个较长的中文输入,用来测试token数量的估算是否准确。", # Medium Chinese
617
+ "This is a longer English input that will be used to test the accuracy of token count estimation." # Medium English
618
+ ]
619
+
620
+ for test_input in inputs:
621
+ if OutputControl.is_verbose():
622
+ print(f"\n=== Testing performance stats with input: {test_input} ===")
623
+ data = create_generate_request_data(test_input, stream=False)
624
+ response = make_request(url, data)
625
+ response_json = response.json()
626
+
627
+ # Verify performance statistics exist and are reasonable
628
+ stats = {
629
+ "total_duration": response_json.get("total_duration"),
630
+ "prompt_eval_count": response_json.get("prompt_eval_count"),
631
+ "prompt_eval_duration": response_json.get("prompt_eval_duration"),
632
+ "eval_count": response_json.get("eval_count"),
633
+ "eval_duration": response_json.get("eval_duration")
634
+ }
635
+ print_json_response(stats, "Performance statistics")
636
+
637
+ def test_generate_concurrent() -> None:
638
+ """Test concurrent generate requests"""
639
+ import asyncio
640
+ import aiohttp
641
+ from contextlib import asynccontextmanager
642
+
643
+ @asynccontextmanager
644
+ async def get_session():
645
+ async with aiohttp.ClientSession() as session:
646
+ yield session
647
+
648
+ async def make_request(session, prompt: str):
649
+ url = get_base_url("generate")
650
+ data = create_generate_request_data(prompt, stream=False)
651
+ try:
652
+ async with session.post(url, json=data) as response:
653
+ return await response.json()
654
+ except Exception as e:
655
+ return {"error": str(e)}
656
+
657
+ async def run_concurrent_requests():
658
+ prompts = [
659
+ "第一个问题",
660
+ "第二个问题",
661
+ "第三个问题",
662
+ "第四个问题",
663
+ "第五个问题"
664
+ ]
665
+
666
+ async with get_session() as session:
667
+ tasks = [make_request(session, prompt) for prompt in prompts]
668
+ results = await asyncio.gather(*tasks)
669
+ return results
670
+
671
+ if OutputControl.is_verbose():
672
+ print("\n=== Testing concurrent generate requests ===")
673
+
674
+ # Run concurrent requests
675
+ results = asyncio.run(run_concurrent_requests())
676
+
677
+ # Print results
678
+ for i, result in enumerate(results, 1):
679
+ print(f"\nRequest {i} result:")
680
+ print_json_response(result)
681
+
682
+ def test_generate_query_modes() -> None:
683
+ """Test different query mode prefixes for generate endpoint"""
684
+ url = get_base_url("generate")
685
+ modes = ["local", "global", "naive", "hybrid", "mix"]
686
+
687
+ for mode in modes:
688
+ if OutputControl.is_verbose():
689
+ print(f"\n=== Testing /{mode} mode for generate ===")
690
+ data = create_generate_request_data(
691
+ f"/{mode} {CONFIG['test_cases']['generate']['query']}",
692
+ stream=False
693
+ )
694
+
695
+ # Send request
696
+ response = make_request(url, data)
697
+ response_json = response.json()
698
+
699
+ # Print response content
700
+ print_json_response(
701
+ {
702
+ "model": response_json["model"],
703
+ "response": response_json["response"],
704
+ "done": response_json["done"]
705
+ }
706
+ )
707
+
708
  def get_test_cases() -> Dict[str, Callable]:
709
  """Get all available test cases
710
  Returns:
 
716
  "modes": test_query_modes,
717
  "errors": test_error_handling,
718
  "stream_errors": test_stream_error_handling,
719
+ "non_stream_generate": test_non_stream_generate,
720
+ "stream_generate": test_stream_generate,
721
+ "generate_with_system": test_generate_with_system,
722
+ "generate_modes": test_generate_query_modes,
723
+ "generate_errors": test_generate_error_handling,
724
+ "generate_stats": test_generate_performance_stats,
725
+ "generate_concurrent": test_generate_concurrent
726
  }
727
 
728
 
 
809
  if "all" in args.tests:
810
  # Run all tests
811
  if OutputControl.is_verbose():
812
+ print("\n【Chat API Tests】")
813
+ run_test(test_non_stream_chat, "Non-streaming Chat Test")
814
+ run_test(test_stream_chat, "Streaming Chat Test")
815
+ run_test(test_query_modes, "Chat Query Mode Test")
816
+ run_test(test_error_handling, "Chat Error Handling Test")
817
+ run_test(test_stream_error_handling, "Chat Streaming Error Test")
 
818
 
819
  if OutputControl.is_verbose():
820
+ print("\n【Generate API Tests】")
821
+ run_test(test_non_stream_generate, "Non-streaming Generate Test")
822
+ run_test(test_stream_generate, "Streaming Generate Test")
823
+ run_test(test_generate_with_system, "Generate with System Prompt Test")
824
+ run_test(test_generate_query_modes, "Generate Query Mode Test")
825
+ run_test(test_generate_error_handling, "Generate Error Handling Test")
826
+ run_test(test_generate_performance_stats, "Generate Performance Stats Test")
827
+ run_test(test_generate_concurrent, "Generate Concurrent Test")
828
  else:
829
  # Run specified tests
830
  for test_name in args.tests: