yangdx commited on
Commit
dcd4769
·
1 Parent(s): 4a5fb8c

Fix linting, remove redundant comments, and clean up code for better readability

Browse files
lightrag/api/lightrag_server.py CHANGED
@@ -476,6 +476,7 @@ class OllamaChatResponse(BaseModel):
476
  message: OllamaMessage
477
  done: bool
478
 
 
479
  class OllamaGenerateRequest(BaseModel):
480
  model: str = LIGHTRAG_MODEL
481
  prompt: str
@@ -483,6 +484,7 @@ class OllamaGenerateRequest(BaseModel):
483
  stream: bool = False
484
  options: Optional[Dict[str, Any]] = None
485
 
 
486
  class OllamaGenerateResponse(BaseModel):
487
  model: str
488
  created_at: str
@@ -490,12 +492,13 @@ class OllamaGenerateResponse(BaseModel):
490
  done: bool
491
  context: Optional[List[int]]
492
  total_duration: Optional[int]
493
- load_duration: Optional[int]
494
  prompt_eval_count: Optional[int]
495
  prompt_eval_duration: Optional[int]
496
  eval_count: Optional[int]
497
  eval_duration: Optional[int]
498
 
 
499
  class OllamaVersionResponse(BaseModel):
500
  version: str
501
 
@@ -1262,52 +1265,45 @@ def create_app(args):
1262
  """Handle generate completion requests"""
1263
  try:
1264
  query = request.prompt
1265
-
1266
- # 开始计时
1267
  start_time = time.time_ns()
1268
-
1269
- # 计算输入token数量
1270
  prompt_tokens = estimate_tokens(query)
1271
-
1272
- # 直接使用 llm_model_func 进行查询
1273
  if request.system:
1274
  rag.llm_model_kwargs["system_prompt"] = request.system
1275
-
1276
  if request.stream:
1277
  from fastapi.responses import StreamingResponse
1278
-
1279
  response = await rag.llm_model_func(
1280
- query,
1281
- stream=True,
1282
- **rag.llm_model_kwargs
1283
  )
1284
-
1285
  async def stream_generator():
1286
  try:
1287
  first_chunk_time = None
1288
  last_chunk_time = None
1289
  total_response = ""
1290
-
1291
- # 处理响应
1292
  if isinstance(response, str):
1293
- # 如果是字符串,分两部分发送
1294
  first_chunk_time = time.time_ns()
1295
  last_chunk_time = first_chunk_time
1296
  total_response = response
1297
-
1298
  data = {
1299
  "model": LIGHTRAG_MODEL,
1300
  "created_at": LIGHTRAG_CREATED_AT,
1301
  "response": response,
1302
- "done": False
1303
  }
1304
  yield f"{json.dumps(data, ensure_ascii=False)}\n"
1305
-
1306
  completion_tokens = estimate_tokens(total_response)
1307
  total_time = last_chunk_time - start_time
1308
  prompt_eval_time = first_chunk_time - start_time
1309
  eval_time = last_chunk_time - first_chunk_time
1310
-
1311
  data = {
1312
  "model": LIGHTRAG_MODEL,
1313
  "created_at": LIGHTRAG_CREATED_AT,
@@ -1317,7 +1313,7 @@ def create_app(args):
1317
  "prompt_eval_count": prompt_tokens,
1318
  "prompt_eval_duration": prompt_eval_time,
1319
  "eval_count": completion_tokens,
1320
- "eval_duration": eval_time
1321
  }
1322
  yield f"{json.dumps(data, ensure_ascii=False)}\n"
1323
  else:
@@ -1325,23 +1321,23 @@ def create_app(args):
1325
  if chunk:
1326
  if first_chunk_time is None:
1327
  first_chunk_time = time.time_ns()
1328
-
1329
  last_chunk_time = time.time_ns()
1330
-
1331
  total_response += chunk
1332
  data = {
1333
  "model": LIGHTRAG_MODEL,
1334
  "created_at": LIGHTRAG_CREATED_AT,
1335
  "response": chunk,
1336
- "done": False
1337
  }
1338
  yield f"{json.dumps(data, ensure_ascii=False)}\n"
1339
-
1340
  completion_tokens = estimate_tokens(total_response)
1341
  total_time = last_chunk_time - start_time
1342
  prompt_eval_time = first_chunk_time - start_time
1343
  eval_time = last_chunk_time - first_chunk_time
1344
-
1345
  data = {
1346
  "model": LIGHTRAG_MODEL,
1347
  "created_at": LIGHTRAG_CREATED_AT,
@@ -1351,15 +1347,15 @@ def create_app(args):
1351
  "prompt_eval_count": prompt_tokens,
1352
  "prompt_eval_duration": prompt_eval_time,
1353
  "eval_count": completion_tokens,
1354
- "eval_duration": eval_time
1355
  }
1356
  yield f"{json.dumps(data, ensure_ascii=False)}\n"
1357
  return
1358
-
1359
  except Exception as e:
1360
  logging.error(f"Error in stream_generator: {str(e)}")
1361
  raise
1362
-
1363
  return StreamingResponse(
1364
  stream_generator(),
1365
  media_type="application/x-ndjson",
@@ -1375,20 +1371,18 @@ def create_app(args):
1375
  else:
1376
  first_chunk_time = time.time_ns()
1377
  response_text = await rag.llm_model_func(
1378
- query,
1379
- stream=False,
1380
- **rag.llm_model_kwargs
1381
  )
1382
  last_chunk_time = time.time_ns()
1383
-
1384
  if not response_text:
1385
  response_text = "No response generated"
1386
-
1387
  completion_tokens = estimate_tokens(str(response_text))
1388
  total_time = last_chunk_time - start_time
1389
  prompt_eval_time = first_chunk_time - start_time
1390
  eval_time = last_chunk_time - first_chunk_time
1391
-
1392
  return {
1393
  "model": LIGHTRAG_MODEL,
1394
  "created_at": LIGHTRAG_CREATED_AT,
@@ -1399,7 +1393,7 @@ def create_app(args):
1399
  "prompt_eval_count": prompt_tokens,
1400
  "prompt_eval_duration": prompt_eval_time,
1401
  "eval_count": completion_tokens,
1402
- "eval_duration": eval_time
1403
  }
1404
  except Exception as e:
1405
  trace_exception(e)
@@ -1417,16 +1411,12 @@ def create_app(args):
1417
  # Get the last message as query
1418
  query = messages[-1].content
1419
 
1420
- # 解析查询模式
1421
  cleaned_query, mode = parse_query_mode(query)
1422
 
1423
- # 开始计时
1424
  start_time = time.time_ns()
1425
-
1426
- # 计算输入token数量
1427
  prompt_tokens = estimate_tokens(cleaned_query)
1428
 
1429
- # 调用RAG进行查询
1430
  query_param = QueryParam(
1431
  mode=mode, stream=request.stream, only_need_context=False
1432
  )
@@ -1537,25 +1527,21 @@ def create_app(args):
1537
  )
1538
  else:
1539
  first_chunk_time = time.time_ns()
1540
-
1541
- # 判断是否包含特定字符串,使用正则表达式进行匹配
1542
- logging.info(f"Cleaned query content: {cleaned_query}")
1543
- match_result = re.search(r'\n<chat_history>\nUSER:', cleaned_query, re.MULTILINE)
1544
- logging.info(f"Regex match result: {bool(match_result)}")
1545
-
1546
- if match_result:
1547
 
 
 
 
 
 
1548
  if request.system:
1549
  rag.llm_model_kwargs["system_prompt"] = request.system
1550
 
1551
  response_text = await rag.llm_model_func(
1552
- cleaned_query,
1553
- stream=False,
1554
- **rag.llm_model_kwargs
1555
  )
1556
  else:
1557
  response_text = await rag.aquery(cleaned_query, param=query_param)
1558
-
1559
  last_chunk_time = time.time_ns()
1560
 
1561
  if not response_text:
 
476
  message: OllamaMessage
477
  done: bool
478
 
479
+
480
  class OllamaGenerateRequest(BaseModel):
481
  model: str = LIGHTRAG_MODEL
482
  prompt: str
 
484
  stream: bool = False
485
  options: Optional[Dict[str, Any]] = None
486
 
487
+
488
  class OllamaGenerateResponse(BaseModel):
489
  model: str
490
  created_at: str
 
492
  done: bool
493
  context: Optional[List[int]]
494
  total_duration: Optional[int]
495
+ load_duration: Optional[int]
496
  prompt_eval_count: Optional[int]
497
  prompt_eval_duration: Optional[int]
498
  eval_count: Optional[int]
499
  eval_duration: Optional[int]
500
 
501
+
502
  class OllamaVersionResponse(BaseModel):
503
  version: str
504
 
 
1265
  """Handle generate completion requests"""
1266
  try:
1267
  query = request.prompt
 
 
1268
  start_time = time.time_ns()
 
 
1269
  prompt_tokens = estimate_tokens(query)
1270
+
 
1271
  if request.system:
1272
  rag.llm_model_kwargs["system_prompt"] = request.system
1273
+
1274
  if request.stream:
1275
  from fastapi.responses import StreamingResponse
1276
+
1277
  response = await rag.llm_model_func(
1278
+ query, stream=True, **rag.llm_model_kwargs
 
 
1279
  )
1280
+
1281
  async def stream_generator():
1282
  try:
1283
  first_chunk_time = None
1284
  last_chunk_time = None
1285
  total_response = ""
1286
+
1287
+ # Ensure response is an async generator
1288
  if isinstance(response, str):
1289
+ # If it's a string, send in two parts
1290
  first_chunk_time = time.time_ns()
1291
  last_chunk_time = first_chunk_time
1292
  total_response = response
1293
+
1294
  data = {
1295
  "model": LIGHTRAG_MODEL,
1296
  "created_at": LIGHTRAG_CREATED_AT,
1297
  "response": response,
1298
+ "done": False,
1299
  }
1300
  yield f"{json.dumps(data, ensure_ascii=False)}\n"
1301
+
1302
  completion_tokens = estimate_tokens(total_response)
1303
  total_time = last_chunk_time - start_time
1304
  prompt_eval_time = first_chunk_time - start_time
1305
  eval_time = last_chunk_time - first_chunk_time
1306
+
1307
  data = {
1308
  "model": LIGHTRAG_MODEL,
1309
  "created_at": LIGHTRAG_CREATED_AT,
 
1313
  "prompt_eval_count": prompt_tokens,
1314
  "prompt_eval_duration": prompt_eval_time,
1315
  "eval_count": completion_tokens,
1316
+ "eval_duration": eval_time,
1317
  }
1318
  yield f"{json.dumps(data, ensure_ascii=False)}\n"
1319
  else:
 
1321
  if chunk:
1322
  if first_chunk_time is None:
1323
  first_chunk_time = time.time_ns()
1324
+
1325
  last_chunk_time = time.time_ns()
1326
+
1327
  total_response += chunk
1328
  data = {
1329
  "model": LIGHTRAG_MODEL,
1330
  "created_at": LIGHTRAG_CREATED_AT,
1331
  "response": chunk,
1332
+ "done": False,
1333
  }
1334
  yield f"{json.dumps(data, ensure_ascii=False)}\n"
1335
+
1336
  completion_tokens = estimate_tokens(total_response)
1337
  total_time = last_chunk_time - start_time
1338
  prompt_eval_time = first_chunk_time - start_time
1339
  eval_time = last_chunk_time - first_chunk_time
1340
+
1341
  data = {
1342
  "model": LIGHTRAG_MODEL,
1343
  "created_at": LIGHTRAG_CREATED_AT,
 
1347
  "prompt_eval_count": prompt_tokens,
1348
  "prompt_eval_duration": prompt_eval_time,
1349
  "eval_count": completion_tokens,
1350
+ "eval_duration": eval_time,
1351
  }
1352
  yield f"{json.dumps(data, ensure_ascii=False)}\n"
1353
  return
1354
+
1355
  except Exception as e:
1356
  logging.error(f"Error in stream_generator: {str(e)}")
1357
  raise
1358
+
1359
  return StreamingResponse(
1360
  stream_generator(),
1361
  media_type="application/x-ndjson",
 
1371
  else:
1372
  first_chunk_time = time.time_ns()
1373
  response_text = await rag.llm_model_func(
1374
+ query, stream=False, **rag.llm_model_kwargs
 
 
1375
  )
1376
  last_chunk_time = time.time_ns()
1377
+
1378
  if not response_text:
1379
  response_text = "No response generated"
1380
+
1381
  completion_tokens = estimate_tokens(str(response_text))
1382
  total_time = last_chunk_time - start_time
1383
  prompt_eval_time = first_chunk_time - start_time
1384
  eval_time = last_chunk_time - first_chunk_time
1385
+
1386
  return {
1387
  "model": LIGHTRAG_MODEL,
1388
  "created_at": LIGHTRAG_CREATED_AT,
 
1393
  "prompt_eval_count": prompt_tokens,
1394
  "prompt_eval_duration": prompt_eval_time,
1395
  "eval_count": completion_tokens,
1396
+ "eval_duration": eval_time,
1397
  }
1398
  except Exception as e:
1399
  trace_exception(e)
 
1411
  # Get the last message as query
1412
  query = messages[-1].content
1413
 
1414
+ # Check for query prefix
1415
  cleaned_query, mode = parse_query_mode(query)
1416
 
 
1417
  start_time = time.time_ns()
 
 
1418
  prompt_tokens = estimate_tokens(cleaned_query)
1419
 
 
1420
  query_param = QueryParam(
1421
  mode=mode, stream=request.stream, only_need_context=False
1422
  )
 
1527
  )
1528
  else:
1529
  first_chunk_time = time.time_ns()
 
 
 
 
 
 
 
1530
 
1531
+ # Determine if the request is from Open WebUI's session title and session keyword generation task
1532
+ match_result = re.search(
1533
+ r"\n<chat_history>\nUSER:", cleaned_query, re.MULTILINE
1534
+ )
1535
+ if match_result:
1536
  if request.system:
1537
  rag.llm_model_kwargs["system_prompt"] = request.system
1538
 
1539
  response_text = await rag.llm_model_func(
1540
+ cleaned_query, stream=False, **rag.llm_model_kwargs
 
 
1541
  )
1542
  else:
1543
  response_text = await rag.aquery(cleaned_query, param=query_param)
1544
+
1545
  last_chunk_time = time.time_ns()
1546
 
1547
  if not response_text:
test_lightrag_ollama_chat.py CHANGED
@@ -110,7 +110,7 @@ DEFAULT_CONFIG = {
110
  },
111
  "test_cases": {
112
  "basic": {"query": "唐僧有几个徒弟"},
113
- "generate": {"query": "电视剧西游记导演是谁"}
114
  },
115
  }
116
 
@@ -205,12 +205,13 @@ def create_chat_request_data(
205
  "stream": stream,
206
  }
207
 
 
208
  def create_generate_request_data(
209
- prompt: str,
210
  system: str = None,
211
- stream: bool = False,
212
  model: str = None,
213
- options: Dict[str, Any] = None
214
  ) -> Dict[str, Any]:
215
  """Create generate request data
216
  Args:
@@ -225,7 +226,7 @@ def create_generate_request_data(
225
  data = {
226
  "model": model or CONFIG["server"]["model"],
227
  "prompt": prompt,
228
- "stream": stream
229
  }
230
  if system:
231
  data["system"] = system
@@ -258,7 +259,9 @@ def run_test(func: Callable, name: str) -> None:
258
  def test_non_stream_chat() -> None:
259
  """Test non-streaming call to /api/chat endpoint"""
260
  url = get_base_url()
261
- data = create_chat_request_data(CONFIG["test_cases"]["basic"]["query"], stream=False)
 
 
262
 
263
  # Send request
264
  response = make_request(url, data)
@@ -487,8 +490,7 @@ def test_non_stream_generate() -> None:
487
  """Test non-streaming call to /api/generate endpoint"""
488
  url = get_base_url("generate")
489
  data = create_generate_request_data(
490
- CONFIG["test_cases"]["generate"]["query"],
491
- stream=False
492
  )
493
 
494
  # Send request
@@ -504,17 +506,17 @@ def test_non_stream_generate() -> None:
504
  {
505
  "model": response_json["model"],
506
  "response": response_json["response"],
507
- "done": response_json["done"]
508
  },
509
- "Response content"
510
  )
511
 
 
512
  def test_stream_generate() -> None:
513
  """Test streaming call to /api/generate endpoint"""
514
  url = get_base_url("generate")
515
  data = create_generate_request_data(
516
- CONFIG["test_cases"]["generate"]["query"],
517
- stream=True
518
  )
519
 
520
  # Send request and get streaming response
@@ -530,13 +532,17 @@ def test_stream_generate() -> None:
530
  # Decode and parse JSON
531
  data = json.loads(line.decode("utf-8"))
532
  if data.get("done", True): # If it's the completion marker
533
- if "total_duration" in data: # Final performance statistics message
 
 
534
  break
535
  else: # Normal content message
536
  content = data.get("response", "")
537
  if content: # Only collect non-empty content
538
  output_buffer.append(content)
539
- print(content, end="", flush=True) # Print content in real-time
 
 
540
  except json.JSONDecodeError:
541
  print("Error decoding JSON from response line")
542
  finally:
@@ -545,13 +551,14 @@ def test_stream_generate() -> None:
545
  # Print a newline
546
  print()
547
 
 
548
  def test_generate_with_system() -> None:
549
  """Test generate with system prompt"""
550
  url = get_base_url("generate")
551
  data = create_generate_request_data(
552
  CONFIG["test_cases"]["generate"]["query"],
553
  system="你是一个知识渊博的助手",
554
- stream=False
555
  )
556
 
557
  # Send request
@@ -567,15 +574,16 @@ def test_generate_with_system() -> None:
567
  {
568
  "model": response_json["model"],
569
  "response": response_json["response"],
570
- "done": response_json["done"]
571
  },
572
- "Response content"
573
  )
574
 
 
575
  def test_generate_error_handling() -> None:
576
  """Test error handling for generate endpoint"""
577
  url = get_base_url("generate")
578
-
579
  # Test empty prompt
580
  if OutputControl.is_verbose():
581
  print("\n=== Testing empty prompt ===")
@@ -583,14 +591,14 @@ def test_generate_error_handling() -> None:
583
  response = make_request(url, data)
584
  print(f"Status code: {response.status_code}")
585
  print_json_response(response.json(), "Error message")
586
-
587
  # Test invalid options
588
  if OutputControl.is_verbose():
589
  print("\n=== Testing invalid options ===")
590
  data = create_generate_request_data(
591
  CONFIG["test_cases"]["basic"]["query"],
592
  options={"invalid_option": "value"},
593
- stream=False
594
  )
595
  response = make_request(url, data)
596
  print(f"Status code: {response.status_code}")
@@ -602,12 +610,12 @@ def test_generate_concurrent() -> None:
602
  import asyncio
603
  import aiohttp
604
  from contextlib import asynccontextmanager
605
-
606
  @asynccontextmanager
607
  async def get_session():
608
  async with aiohttp.ClientSession() as session:
609
  yield session
610
-
611
  async def make_request(session, prompt: str):
612
  url = get_base_url("generate")
613
  data = create_generate_request_data(prompt, stream=False)
@@ -616,32 +624,27 @@ def test_generate_concurrent() -> None:
616
  return await response.json()
617
  except Exception as e:
618
  return {"error": str(e)}
619
-
620
  async def run_concurrent_requests():
621
- prompts = [
622
- "第一个问题",
623
- "第二个问题",
624
- "第三个问题",
625
- "第四个问题",
626
- "第五个问题"
627
- ]
628
-
629
  async with get_session() as session:
630
  tasks = [make_request(session, prompt) for prompt in prompts]
631
  results = await asyncio.gather(*tasks)
632
  return results
633
-
634
  if OutputControl.is_verbose():
635
  print("\n=== Testing concurrent generate requests ===")
636
-
637
  # Run concurrent requests
638
  results = asyncio.run(run_concurrent_requests())
639
-
640
  # Print results
641
  for i, result in enumerate(results, 1):
642
  print(f"\nRequest {i} result:")
643
  print_json_response(result)
644
 
 
645
  def get_test_cases() -> Dict[str, Callable]:
646
  """Get all available test cases
647
  Returns:
@@ -657,7 +660,7 @@ def get_test_cases() -> Dict[str, Callable]:
657
  "stream_generate": test_stream_generate,
658
  "generate_with_system": test_generate_with_system,
659
  "generate_errors": test_generate_error_handling,
660
- "generate_concurrent": test_generate_concurrent
661
  }
662
 
663
 
 
110
  },
111
  "test_cases": {
112
  "basic": {"query": "唐僧有几个徒弟"},
113
+ "generate": {"query": "电视剧西游记导演是谁"},
114
  },
115
  }
116
 
 
205
  "stream": stream,
206
  }
207
 
208
+
209
  def create_generate_request_data(
210
+ prompt: str,
211
  system: str = None,
212
+ stream: bool = False,
213
  model: str = None,
214
+ options: Dict[str, Any] = None,
215
  ) -> Dict[str, Any]:
216
  """Create generate request data
217
  Args:
 
226
  data = {
227
  "model": model or CONFIG["server"]["model"],
228
  "prompt": prompt,
229
+ "stream": stream,
230
  }
231
  if system:
232
  data["system"] = system
 
259
  def test_non_stream_chat() -> None:
260
  """Test non-streaming call to /api/chat endpoint"""
261
  url = get_base_url()
262
+ data = create_chat_request_data(
263
+ CONFIG["test_cases"]["basic"]["query"], stream=False
264
+ )
265
 
266
  # Send request
267
  response = make_request(url, data)
 
490
  """Test non-streaming call to /api/generate endpoint"""
491
  url = get_base_url("generate")
492
  data = create_generate_request_data(
493
+ CONFIG["test_cases"]["generate"]["query"], stream=False
 
494
  )
495
 
496
  # Send request
 
506
  {
507
  "model": response_json["model"],
508
  "response": response_json["response"],
509
+ "done": response_json["done"],
510
  },
511
+ "Response content",
512
  )
513
 
514
+
515
  def test_stream_generate() -> None:
516
  """Test streaming call to /api/generate endpoint"""
517
  url = get_base_url("generate")
518
  data = create_generate_request_data(
519
+ CONFIG["test_cases"]["generate"]["query"], stream=True
 
520
  )
521
 
522
  # Send request and get streaming response
 
532
  # Decode and parse JSON
533
  data = json.loads(line.decode("utf-8"))
534
  if data.get("done", True): # If it's the completion marker
535
+ if (
536
+ "total_duration" in data
537
+ ): # Final performance statistics message
538
  break
539
  else: # Normal content message
540
  content = data.get("response", "")
541
  if content: # Only collect non-empty content
542
  output_buffer.append(content)
543
+ print(
544
+ content, end="", flush=True
545
+ ) # Print content in real-time
546
  except json.JSONDecodeError:
547
  print("Error decoding JSON from response line")
548
  finally:
 
551
  # Print a newline
552
  print()
553
 
554
+
555
  def test_generate_with_system() -> None:
556
  """Test generate with system prompt"""
557
  url = get_base_url("generate")
558
  data = create_generate_request_data(
559
  CONFIG["test_cases"]["generate"]["query"],
560
  system="你是一个知识渊博的助手",
561
+ stream=False,
562
  )
563
 
564
  # Send request
 
574
  {
575
  "model": response_json["model"],
576
  "response": response_json["response"],
577
+ "done": response_json["done"],
578
  },
579
+ "Response content",
580
  )
581
 
582
+
583
  def test_generate_error_handling() -> None:
584
  """Test error handling for generate endpoint"""
585
  url = get_base_url("generate")
586
+
587
  # Test empty prompt
588
  if OutputControl.is_verbose():
589
  print("\n=== Testing empty prompt ===")
 
591
  response = make_request(url, data)
592
  print(f"Status code: {response.status_code}")
593
  print_json_response(response.json(), "Error message")
594
+
595
  # Test invalid options
596
  if OutputControl.is_verbose():
597
  print("\n=== Testing invalid options ===")
598
  data = create_generate_request_data(
599
  CONFIG["test_cases"]["basic"]["query"],
600
  options={"invalid_option": "value"},
601
+ stream=False,
602
  )
603
  response = make_request(url, data)
604
  print(f"Status code: {response.status_code}")
 
610
  import asyncio
611
  import aiohttp
612
  from contextlib import asynccontextmanager
613
+
614
  @asynccontextmanager
615
  async def get_session():
616
  async with aiohttp.ClientSession() as session:
617
  yield session
618
+
619
  async def make_request(session, prompt: str):
620
  url = get_base_url("generate")
621
  data = create_generate_request_data(prompt, stream=False)
 
624
  return await response.json()
625
  except Exception as e:
626
  return {"error": str(e)}
627
+
628
  async def run_concurrent_requests():
629
+ prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"]
630
+
 
 
 
 
 
 
631
  async with get_session() as session:
632
  tasks = [make_request(session, prompt) for prompt in prompts]
633
  results = await asyncio.gather(*tasks)
634
  return results
635
+
636
  if OutputControl.is_verbose():
637
  print("\n=== Testing concurrent generate requests ===")
638
+
639
  # Run concurrent requests
640
  results = asyncio.run(run_concurrent_requests())
641
+
642
  # Print results
643
  for i, result in enumerate(results, 1):
644
  print(f"\nRequest {i} result:")
645
  print_json_response(result)
646
 
647
+
648
  def get_test_cases() -> Dict[str, Callable]:
649
  """Get all available test cases
650
  Returns:
 
660
  "stream_generate": test_stream_generate,
661
  "generate_with_system": test_generate_with_system,
662
  "generate_errors": test_generate_error_handling,
663
+ "generate_concurrent": test_generate_concurrent,
664
  }
665
 
666