zrguo committed
Commit dc1567d (unverified)
Parents: ee9eb3b dcd4769

Merge pull request #644 from danielaskdd/Add-Ollama-generate-API-support

lightrag/api/README.md CHANGED
@@ -94,6 +94,7 @@ For example, chat message "/mix 唐僧有几个徒弟" will trigger a mix mode q

After starting the lightrag-server, you can add an Ollama-type connection in the Open WebUI admin panel. A model named lightrag:latest will then appear in Open WebUI's model management interface, and users can send queries to LightRAG through the chat interface.

+To prevent Open WebUI from using LightRAG when generating conversation titles, go to Admin Panel > Interface > Set Task Model and change both Local Models and External Models to any option except "Current Model".

## Configuration

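For reference, the Ollama-compatible chat endpoint that Open WebUI connects to can also be exercised directly. The sketch below is illustrative only: it assumes the LightRAG API server is reachable at http://localhost:9621 (substitute your own host and port) and uses the lightrag:latest model name and the /mix prefix mentioned above.

import json
import requests

BASE_URL = "http://localhost:9621"  # assumed server address; adjust to your deployment

payload = {
    "model": "lightrag:latest",
    "messages": [{"role": "user", "content": "/mix 唐僧有几个徒弟"}],
    "stream": False,
}

# POST to the Ollama-compatible chat endpoint and print the raw JSON reply
resp = requests.post(f"{BASE_URL}/api/chat", json=payload, timeout=120)
resp.raise_for_status()
print(json.dumps(resp.json(), ensure_ascii=False, indent=2))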
lightrag/api/lightrag_server.py CHANGED
@@ -533,6 +533,7 @@ class OllamaChatRequest(BaseModel):
    messages: List[OllamaMessage]
    stream: bool = True  # Default to streaming mode
    options: Optional[Dict[str, Any]] = None
+    system: Optional[str] = None


class OllamaChatResponse(BaseModel):
@@ -542,6 +543,28 @@ class OllamaChatResponse(BaseModel):
    done: bool


+class OllamaGenerateRequest(BaseModel):
+    model: str = LIGHTRAG_MODEL
+    prompt: str
+    system: Optional[str] = None
+    stream: bool = False
+    options: Optional[Dict[str, Any]] = None
+
+
+class OllamaGenerateResponse(BaseModel):
+    model: str
+    created_at: str
+    response: str
+    done: bool
+    context: Optional[List[int]]
+    total_duration: Optional[int]
+    load_duration: Optional[int]
+    prompt_eval_count: Optional[int]
+    prompt_eval_duration: Optional[int]
+    eval_count: Optional[int]
+    eval_duration: Optional[int]
+
+
class OllamaVersionResponse(BaseModel):
    version: str

@@ -1417,6 +1440,145 @@ def create_app(args):

        return query, SearchMode.hybrid

+    @app.post("/api/generate")
+    async def generate(raw_request: Request, request: OllamaGenerateRequest):
+        """Handle generate completion requests"""
+        try:
+            query = request.prompt
+            start_time = time.time_ns()
+            prompt_tokens = estimate_tokens(query)
+
+            if request.system:
+                rag.llm_model_kwargs["system_prompt"] = request.system
+
+            if request.stream:
+                from fastapi.responses import StreamingResponse
+
+                response = await rag.llm_model_func(
+                    query, stream=True, **rag.llm_model_kwargs
+                )
+
+                async def stream_generator():
+                    try:
+                        first_chunk_time = None
+                        last_chunk_time = None
+                        total_response = ""
+
+                        # Ensure response is an async generator
+                        if isinstance(response, str):
+                            # If it's a string, send in two parts
+                            first_chunk_time = time.time_ns()
+                            last_chunk_time = first_chunk_time
+                            total_response = response
+
+                            data = {
+                                "model": LIGHTRAG_MODEL,
+                                "created_at": LIGHTRAG_CREATED_AT,
+                                "response": response,
+                                "done": False,
+                            }
+                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
+
+                            completion_tokens = estimate_tokens(total_response)
+                            total_time = last_chunk_time - start_time
+                            prompt_eval_time = first_chunk_time - start_time
+                            eval_time = last_chunk_time - first_chunk_time
+
+                            data = {
+                                "model": LIGHTRAG_MODEL,
+                                "created_at": LIGHTRAG_CREATED_AT,
+                                "done": True,
+                                "total_duration": total_time,
+                                "load_duration": 0,
+                                "prompt_eval_count": prompt_tokens,
+                                "prompt_eval_duration": prompt_eval_time,
+                                "eval_count": completion_tokens,
+                                "eval_duration": eval_time,
+                            }
+                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
+                        else:
+                            async for chunk in response:
+                                if chunk:
+                                    if first_chunk_time is None:
+                                        first_chunk_time = time.time_ns()
+
+                                    last_chunk_time = time.time_ns()
+
+                                    total_response += chunk
+                                    data = {
+                                        "model": LIGHTRAG_MODEL,
+                                        "created_at": LIGHTRAG_CREATED_AT,
+                                        "response": chunk,
+                                        "done": False,
+                                    }
+                                    yield f"{json.dumps(data, ensure_ascii=False)}\n"
+
+                            completion_tokens = estimate_tokens(total_response)
+                            total_time = last_chunk_time - start_time
+                            prompt_eval_time = first_chunk_time - start_time
+                            eval_time = last_chunk_time - first_chunk_time
+
+                            data = {
+                                "model": LIGHTRAG_MODEL,
+                                "created_at": LIGHTRAG_CREATED_AT,
+                                "done": True,
+                                "total_duration": total_time,
+                                "load_duration": 0,
+                                "prompt_eval_count": prompt_tokens,
+                                "prompt_eval_duration": prompt_eval_time,
+                                "eval_count": completion_tokens,
+                                "eval_duration": eval_time,
+                            }
+                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
+                        return
+
+                    except Exception as e:
+                        logging.error(f"Error in stream_generator: {str(e)}")
+                        raise
+
+                return StreamingResponse(
+                    stream_generator(),
+                    media_type="application/x-ndjson",
+                    headers={
+                        "Cache-Control": "no-cache",
+                        "Connection": "keep-alive",
+                        "Content-Type": "application/x-ndjson",
+                        "Access-Control-Allow-Origin": "*",
+                        "Access-Control-Allow-Methods": "POST, OPTIONS",
+                        "Access-Control-Allow-Headers": "Content-Type",
+                    },
+                )
+            else:
+                first_chunk_time = time.time_ns()
+                response_text = await rag.llm_model_func(
+                    query, stream=False, **rag.llm_model_kwargs
+                )
+                last_chunk_time = time.time_ns()
+
+                if not response_text:
+                    response_text = "No response generated"
+
+                completion_tokens = estimate_tokens(str(response_text))
+                total_time = last_chunk_time - start_time
+                prompt_eval_time = first_chunk_time - start_time
+                eval_time = last_chunk_time - first_chunk_time
+
+                return {
+                    "model": LIGHTRAG_MODEL,
+                    "created_at": LIGHTRAG_CREATED_AT,
+                    "response": str(response_text),
+                    "done": True,
+                    "total_duration": total_time,
+                    "load_duration": 0,
+                    "prompt_eval_count": prompt_tokens,
+                    "prompt_eval_duration": prompt_eval_time,
+                    "eval_count": completion_tokens,
+                    "eval_duration": eval_time,
+                }
+        except Exception as e:
+            trace_exception(e)
+            raise HTTPException(status_code=500, detail=str(e))
+
    @app.post("/api/chat")
    async def chat(raw_request: Request, request: OllamaChatRequest):
        """Handle chat completion requests"""
@@ -1429,16 +1591,12 @@ def create_app(args):
            # Get the last message as query
            query = messages[-1].content

-            # 解析查询模式
+            # Check for query prefix
            cleaned_query, mode = parse_query_mode(query)

-            # 开始计时
            start_time = time.time_ns()
-
-            # 计算输入token数量
            prompt_tokens = estimate_tokens(cleaned_query)

-            # 调用RAG进行查询
            query_param = QueryParam(
                mode=mode, stream=request.stream, only_need_context=False
            )
@@ -1549,7 +1707,21 @@ def create_app(args):
                )
            else:
                first_chunk_time = time.time_ns()
-                response_text = await rag.aquery(cleaned_query, param=query_param)
+
+                # Determine if the request is from Open WebUI's session title and session keyword generation task
+                match_result = re.search(
+                    r"\n<chat_history>\nUSER:", cleaned_query, re.MULTILINE
+                )
+                if match_result:
+                    if request.system:
+                        rag.llm_model_kwargs["system_prompt"] = request.system
+
+                    response_text = await rag.llm_model_func(
+                        cleaned_query, stream=False, **rag.llm_model_kwargs
+                    )
+                else:
+                    response_text = await rag.aquery(cleaned_query, param=query_param)
+
                last_chunk_time = time.time_ns()

                if not response_text:
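The new endpoint can be smoke-tested outside Open WebUI as well. The sketch below is illustrative only: it assumes the server is reachable at http://localhost:9621 and uses the lightrag:latest model name from the README. The request fields mirror OllamaGenerateRequest above, and the streaming branch reads the NDJSON lines emitted by stream_generator(), where intermediate chunks carry "response" with done=false and the final line carries done=true plus timing statistics.

import json
import requests

BASE_URL = "http://localhost:9621"  # assumed server address; adjust to your deployment

# Non-streaming: the handler returns one JSON object whose "response" field holds the text
resp = requests.post(
    f"{BASE_URL}/api/generate",
    json={"model": "lightrag:latest", "prompt": "电视剧西游记导演是谁", "stream": False},
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["response"])

# Streaming: read NDJSON lines until the final statistics message (done=true) arrives
with requests.post(
    f"{BASE_URL}/api/generate",
    json={"model": "lightrag:latest", "prompt": "电视剧西游记导演是谁", "stream": True},
    stream=True,
    timeout=120,
) as stream_resp:
    for line in stream_resp.iter_lines():
        if not line:
            continue
        chunk = json.loads(line.decode("utf-8"))
        if chunk.get("done"):
            break
        print(chunk.get("response", ""), end="", flush=True)
    print()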
test_lightrag_ollama_chat.py CHANGED
@@ -108,7 +108,10 @@ DEFAULT_CONFIG = {
        "max_retries": 3,
        "retry_delay": 1,
    },
-    "test_cases": {"basic": {"query": "唐僧有几个徒弟"}},
+    "test_cases": {
+        "basic": {"query": "唐僧有几个徒弟"},
+        "generate": {"query": "电视剧西游记导演是谁"},
+    },
}


@@ -174,22 +177,27 @@ def print_json_response(data: Dict[str, Any], title: str = "", indent: int = 2)
CONFIG = load_config()


-def get_base_url() -> str:
-    """Return the base URL"""
+def get_base_url(endpoint: str = "chat") -> str:
+    """Return the base URL for specified endpoint
+    Args:
+        endpoint: API endpoint name (chat or generate)
+    Returns:
+        Complete URL for the endpoint
+    """
    server = CONFIG["server"]
-    return f"http://{server['host']}:{server['port']}/api/chat"
+    return f"http://{server['host']}:{server['port']}/api/{endpoint}"


-def create_request_data(
+def create_chat_request_data(
    content: str, stream: bool = False, model: str = None
) -> Dict[str, Any]:
-    """Create basic request data
+    """Create chat request data
    Args:
        content: User message content
        stream: Whether to use streaming response
        model: Model name
    Returns:
-        Dictionary containing complete request data
+        Dictionary containing complete chat request data
    """
    return {
        "model": model or CONFIG["server"]["model"],
@@ -198,6 +206,35 @@ def create_request_data(
    }


+def create_generate_request_data(
+    prompt: str,
+    system: str = None,
+    stream: bool = False,
+    model: str = None,
+    options: Dict[str, Any] = None,
+) -> Dict[str, Any]:
+    """Create generate request data
+    Args:
+        prompt: Generation prompt
+        system: System prompt
+        stream: Whether to use streaming response
+        model: Model name
+        options: Additional options
+    Returns:
+        Dictionary containing complete generate request data
+    """
+    data = {
+        "model": model or CONFIG["server"]["model"],
+        "prompt": prompt,
+        "stream": stream,
+    }
+    if system:
+        data["system"] = system
+    if options:
+        data["options"] = options
+    return data
+
+
# Global test statistics
STATS = TestStats()

@@ -219,10 +256,12 @@ def run_test(func: Callable, name: str) -> None:
        raise


-def test_non_stream_chat():
+def test_non_stream_chat() -> None:
    """Test non-streaming call to /api/chat endpoint"""
    url = get_base_url()
-    data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=False)
+    data = create_chat_request_data(
+        CONFIG["test_cases"]["basic"]["query"], stream=False
+    )

    # Send request
    response = make_request(url, data)
@@ -239,7 +278,7 @@ def test_non_stream_chat():
    )


-def test_stream_chat():
+def test_stream_chat() -> None:
    """Test streaming call to /api/chat endpoint

    Use JSON Lines format to process streaming responses, each line is a complete JSON object.
@@ -258,7 +297,7 @@ def test_stream_chat():
    The last message will contain performance statistics, with done set to true.
    """
    url = get_base_url()
-    data = create_request_data(CONFIG["test_cases"]["basic"]["query"], stream=True)
+    data = create_chat_request_data(CONFIG["test_cases"]["basic"]["query"], stream=True)

    # Send request and get streaming response
    response = make_request(url, data, stream=True)
@@ -295,7 +334,7 @@ def test_stream_chat():
    print()


-def test_query_modes():
+def test_query_modes() -> None:
    """Test different query mode prefixes

    Supported query modes:
@@ -313,7 +352,7 @@ def test_query_modes():
    for mode in modes:
        if OutputControl.is_verbose():
            print(f"\n=== Testing /{mode} mode ===")
-        data = create_request_data(
+        data = create_chat_request_data(
            f"/{mode} {CONFIG['test_cases']['basic']['query']}", stream=False
        )

@@ -354,7 +393,7 @@ def create_error_test_data(error_type: str) -> Dict[str, Any]:
    return error_data.get(error_type, error_data["empty_messages"])


-def test_stream_error_handling():
+def test_stream_error_handling() -> None:
    """Test error handling for streaming responses

    Test scenarios:
@@ -400,7 +439,7 @@ def test_stream_error_handling():
    response.close()


-def test_error_handling():
+def test_error_handling() -> None:
    """Test error handling for non-streaming responses

    Test scenarios:
@@ -447,6 +486,165 @@ def test_error_handling():
    print_json_response(response.json(), "Error message")


+def test_non_stream_generate() -> None:
+    """Test non-streaming call to /api/generate endpoint"""
+    url = get_base_url("generate")
+    data = create_generate_request_data(
+        CONFIG["test_cases"]["generate"]["query"], stream=False
+    )
+
+    # Send request
+    response = make_request(url, data)
+
+    # Print response
+    if OutputControl.is_verbose():
+        print("\n=== Non-streaming generate response ===")
+    response_json = response.json()
+
+    # Print response content
+    print_json_response(
+        {
+            "model": response_json["model"],
+            "response": response_json["response"],
+            "done": response_json["done"],
+        },
+        "Response content",
+    )
+
+
+def test_stream_generate() -> None:
+    """Test streaming call to /api/generate endpoint"""
+    url = get_base_url("generate")
+    data = create_generate_request_data(
+        CONFIG["test_cases"]["generate"]["query"], stream=True
+    )
+
+    # Send request and get streaming response
+    response = make_request(url, data, stream=True)
+
+    if OutputControl.is_verbose():
+        print("\n=== Streaming generate response ===")
+    output_buffer = []
+    try:
+        for line in response.iter_lines():
+            if line:  # Skip empty lines
+                try:
+                    # Decode and parse JSON
+                    data = json.loads(line.decode("utf-8"))
+                    if data.get("done", True):  # If it's the completion marker
+                        if (
+                            "total_duration" in data
+                        ):  # Final performance statistics message
+                            break
+                    else:  # Normal content message
+                        content = data.get("response", "")
+                        if content:  # Only collect non-empty content
+                            output_buffer.append(content)
+                            print(
+                                content, end="", flush=True
+                            )  # Print content in real-time
+                except json.JSONDecodeError:
+                    print("Error decoding JSON from response line")
+    finally:
+        response.close()  # Ensure the response connection is closed
+
+    # Print a newline
+    print()
+
+
+def test_generate_with_system() -> None:
+    """Test generate with system prompt"""
+    url = get_base_url("generate")
+    data = create_generate_request_data(
+        CONFIG["test_cases"]["generate"]["query"],
+        system="你是一个知识渊博的助手",
+        stream=False,
+    )
+
+    # Send request
+    response = make_request(url, data)
+
+    # Print response
+    if OutputControl.is_verbose():
+        print("\n=== Generate with system prompt response ===")
+    response_json = response.json()
+
+    # Print response content
+    print_json_response(
+        {
+            "model": response_json["model"],
+            "response": response_json["response"],
+            "done": response_json["done"],
+        },
+        "Response content",
+    )
+
+
+def test_generate_error_handling() -> None:
+    """Test error handling for generate endpoint"""
+    url = get_base_url("generate")
+
+    # Test empty prompt
+    if OutputControl.is_verbose():
+        print("\n=== Testing empty prompt ===")
+    data = create_generate_request_data("", stream=False)
+    response = make_request(url, data)
+    print(f"Status code: {response.status_code}")
+    print_json_response(response.json(), "Error message")
+
+    # Test invalid options
+    if OutputControl.is_verbose():
+        print("\n=== Testing invalid options ===")
+    data = create_generate_request_data(
+        CONFIG["test_cases"]["basic"]["query"],
+        options={"invalid_option": "value"},
+        stream=False,
+    )
+    response = make_request(url, data)
+    print(f"Status code: {response.status_code}")
+    print_json_response(response.json(), "Error message")
+
+
+def test_generate_concurrent() -> None:
+    """Test concurrent generate requests"""
+    import asyncio
+    import aiohttp
+    from contextlib import asynccontextmanager
+
+    @asynccontextmanager
+    async def get_session():
+        async with aiohttp.ClientSession() as session:
+            yield session
+
+    async def make_request(session, prompt: str):
+        url = get_base_url("generate")
+        data = create_generate_request_data(prompt, stream=False)
+        try:
+            async with session.post(url, json=data) as response:
+                return await response.json()
+        except Exception as e:
+            return {"error": str(e)}
+
+    async def run_concurrent_requests():
+        prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"]
+
+        async with get_session() as session:
+            tasks = [make_request(session, prompt) for prompt in prompts]
+            results = await asyncio.gather(*tasks)
+            return results
+
+    if OutputControl.is_verbose():
+        print("\n=== Testing concurrent generate requests ===")
+
+    # Run concurrent requests
+    results = asyncio.run(run_concurrent_requests())
+
+    # Print results
+    for i, result in enumerate(results, 1):
+        print(f"\nRequest {i} result:")
+        print_json_response(result)
+
+
def get_test_cases() -> Dict[str, Callable]:
    """Get all available test cases
    Returns:
@@ -458,6 +656,11 @@ def get_test_cases() -> Dict[str, Callable]:
        "modes": test_query_modes,
        "errors": test_error_handling,
        "stream_errors": test_stream_error_handling,
+        "non_stream_generate": test_non_stream_generate,
+        "stream_generate": test_stream_generate,
+        "generate_with_system": test_generate_with_system,
+        "generate_errors": test_generate_error_handling,
+        "generate_concurrent": test_generate_concurrent,
    }


@@ -544,18 +747,20 @@ if __name__ == "__main__":
    if "all" in args.tests:
        # Run all tests
        if OutputControl.is_verbose():
-            print("\n【Basic Functionality Tests】")
-        run_test(test_non_stream_chat, "Non-streaming Call Test")
-        run_test(test_stream_chat, "Streaming Call Test")
-
-        if OutputControl.is_verbose():
-            print("\n【Query Mode Tests】")
-        run_test(test_query_modes, "Query Mode Test")
+            print("\n【Chat API Tests】")
+        run_test(test_non_stream_chat, "Non-streaming Chat Test")
+        run_test(test_stream_chat, "Streaming Chat Test")
+        run_test(test_query_modes, "Chat Query Mode Test")
+        run_test(test_error_handling, "Chat Error Handling Test")
+        run_test(test_stream_error_handling, "Chat Streaming Error Test")

        if OutputControl.is_verbose():
-            print("\n【Error Handling Tests】")
-        run_test(test_error_handling, "Error Handling Test")
-        run_test(test_stream_error_handling, "Streaming Error Handling Test")
+            print("\n【Generate API Tests】")
+        run_test(test_non_stream_generate, "Non-streaming Generate Test")
+        run_test(test_stream_generate, "Streaming Generate Test")
+        run_test(test_generate_with_system, "Generate with System Prompt Test")
+        run_test(test_generate_error_handling, "Generate Error Handling Test")
+        run_test(test_generate_concurrent, "Generate Concurrent Test")
    else:
        # Run specified tests
        for test_name in args.tests:
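If you prefer to run only the new generate tests, the script registers them under the keys shown in get_test_cases(). A minimal sketch, assuming the module is importable from the repository root and a LightRAG server is already running with the configured host and port:

from test_lightrag_ollama_chat import (
    run_test,
    test_generate_with_system,
    test_non_stream_generate,
    test_stream_generate,
)

# Reuse the script's own runner so results are counted in its statistics
run_test(test_non_stream_generate, "Non-streaming Generate Test")
run_test(test_stream_generate, "Streaming Generate Test")
run_test(test_generate_with_system, "Generate with System Prompt Test")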