yangdx commited on
Commit
bca22a9
·
2 Parent(s): 8f6d644 b59560d

Merge tag 'time-temp' into improve-ollama-api-streaming

Browse files
lightrag/api/ollama_api.py CHANGED
@@ -205,14 +205,14 @@ class OllamaAPI:
205
  async def stream_generator():
206
  try:
207
  first_chunk_time = None
208
- last_chunk_time = None
209
  total_response = ""
210
 
211
  # Ensure response is an async generator
212
  if isinstance(response, str):
213
  # If it's a string, send in two parts
214
- first_chunk_time = time.time_ns()
215
- last_chunk_time = first_chunk_time
216
  total_response = response
217
 
218
  data = {
@@ -241,22 +241,50 @@ class OllamaAPI:
241
  }
242
  yield f"{json.dumps(data, ensure_ascii=False)}\n"
243
  else:
244
- async for chunk in response:
245
- if chunk:
246
- if first_chunk_time is None:
247
- first_chunk_time = time.time_ns()
248
-
249
- last_chunk_time = time.time_ns()
250
-
251
- total_response += chunk
252
- data = {
253
- "model": self.ollama_server_infos.LIGHTRAG_MODEL,
254
- "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
255
- "response": chunk,
256
- "done": False,
257
- }
258
- yield f"{json.dumps(data, ensure_ascii=False)}\n"
259
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  completion_tokens = estimate_tokens(total_response)
261
  total_time = last_chunk_time - start_time
262
  prompt_eval_time = first_chunk_time - start_time
@@ -381,16 +409,16 @@ class OllamaAPI:
381
  )
382
 
383
  async def stream_generator():
384
- first_chunk_time = None
385
- last_chunk_time = None
386
- total_response = ""
387
-
388
  try:
 
 
 
 
389
  # Ensure response is an async generator
390
  if isinstance(response, str):
391
  # If it's a string, send in two parts
392
- first_chunk_time = time.time_ns()
393
- last_chunk_time = first_chunk_time
394
  total_response = response
395
 
396
  data = {
@@ -474,45 +502,29 @@ class OllamaAPI:
474
  yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
475
  return
476
 
477
- if last_chunk_time is not None:
478
- completion_tokens = estimate_tokens(total_response)
479
- total_time = last_chunk_time - start_time
480
- prompt_eval_time = first_chunk_time - start_time
481
- eval_time = last_chunk_time - first_chunk_time
 
482
 
483
- data = {
484
- "model": self.ollama_server_infos.LIGHTRAG_MODEL,
485
- "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
486
- "done": True,
487
- "total_duration": total_time,
488
- "load_duration": 0,
489
- "prompt_eval_count": prompt_tokens,
490
- "prompt_eval_duration": prompt_eval_time,
491
- "eval_count": completion_tokens,
492
- "eval_duration": eval_time,
493
- }
494
- yield f"{json.dumps(data, ensure_ascii=False)}\n"
495
 
496
  except Exception as e:
497
- error_msg = f"Error in stream_generator: {str(e)}"
498
- logging.error(error_msg)
499
-
500
- # Send error message to client
501
- error_data = {
502
- "model": self.ollama_server_infos.LIGHTRAG_MODEL,
503
- "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
504
- "error": {"code": "STREAM_ERROR", "message": error_msg},
505
- }
506
- yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
507
-
508
- # Ensure sending end marker
509
- final_data = {
510
- "model": self.ollama_server_infos.LIGHTRAG_MODEL,
511
- "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
512
- "done": True,
513
- }
514
- yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
515
- return
516
 
517
  return StreamingResponse(
518
  stream_generator(),
 
205
  async def stream_generator():
206
  try:
207
  first_chunk_time = None
208
+ last_chunk_time = time.time_ns()
209
  total_response = ""
210
 
211
  # Ensure response is an async generator
212
  if isinstance(response, str):
213
  # If it's a string, send in two parts
214
+ first_chunk_time = start_time
215
+ last_chunk_time = time.time_ns()
216
  total_response = response
217
 
218
  data = {
 
241
  }
242
  yield f"{json.dumps(data, ensure_ascii=False)}\n"
243
  else:
244
+ try:
245
+ async for chunk in response:
246
+ if chunk:
247
+ if first_chunk_time is None:
248
+ first_chunk_time = time.time_ns()
 
 
 
 
 
 
 
 
 
 
249
 
250
+ last_chunk_time = time.time_ns()
251
+
252
+ total_response += chunk
253
+ data = {
254
+ "model": self.ollama_server_infos.LIGHTRAG_MODEL,
255
+ "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
256
+ "response": chunk,
257
+ "done": False,
258
+ }
259
+ yield f"{json.dumps(data, ensure_ascii=False)}\n"
260
+ except (asyncio.CancelledError, Exception) as e:
261
+ error_msg = str(e)
262
+ if isinstance(e, asyncio.CancelledError):
263
+ error_msg = "Stream was cancelled by server"
264
+ else:
265
+ error_msg = f"Provider error: {error_msg}"
266
+
267
+ logging.error(f"Stream error: {error_msg}")
268
+
269
+ # Send error message to client
270
+ error_data = {
271
+ "model": self.ollama_server_infos.LIGHTRAG_MODEL,
272
+ "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
273
+ "response": f"\n\nError: {error_msg}",
274
+ "done": False,
275
+ }
276
+ yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
277
+
278
+ # Send final message to close the stream
279
+ final_data = {
280
+ "model": self.ollama_server_infos.LIGHTRAG_MODEL,
281
+ "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
282
+ "done": True,
283
+ }
284
+ yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
285
+ return
286
+ if first_chunk_time is None:
287
+ first_chunk_time = start_time
288
  completion_tokens = estimate_tokens(total_response)
289
  total_time = last_chunk_time - start_time
290
  prompt_eval_time = first_chunk_time - start_time
 
409
  )
410
 
411
  async def stream_generator():
 
 
 
 
412
  try:
413
+ first_chunk_time = None
414
+ last_chunk_time = time.time_ns()
415
+ total_response = ""
416
+
417
  # Ensure response is an async generator
418
  if isinstance(response, str):
419
  # If it's a string, send in two parts
420
+ first_chunk_time = start_time
421
+ last_chunk_time = time.time_ns()
422
  total_response = response
423
 
424
  data = {
 
502
  yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
503
  return
504
 
505
+ if first_chunk_time is None:
506
+ first_chunk_time = start_time
507
+ completion_tokens = estimate_tokens(total_response)
508
+ total_time = last_chunk_time - start_time
509
+ prompt_eval_time = first_chunk_time - start_time
510
+ eval_time = last_chunk_time - first_chunk_time
511
 
512
+ data = {
513
+ "model": self.ollama_server_infos.LIGHTRAG_MODEL,
514
+ "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
515
+ "done": True,
516
+ "total_duration": total_time,
517
+ "load_duration": 0,
518
+ "prompt_eval_count": prompt_tokens,
519
+ "prompt_eval_duration": prompt_eval_time,
520
+ "eval_count": completion_tokens,
521
+ "eval_duration": eval_time,
522
+ }
523
+ yield f"{json.dumps(data, ensure_ascii=False)}\n"
524
 
525
  except Exception as e:
526
+ trace_exception(e)
527
+ raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
528
 
529
  return StreamingResponse(
530
  stream_generator(),
test_lightrag_ollama_chat.py CHANGED
@@ -17,6 +17,24 @@ from typing import Dict, Any, Optional, List, Callable
17
  from dataclasses import dataclass, asdict
18
  from datetime import datetime
19
  from pathlib import Path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  DEFAULT_CONFIG = {
22
  "server": {
@@ -634,35 +652,82 @@ def test_generate_concurrent() -> None:
634
  async with aiohttp.ClientSession() as session:
635
  yield session
636
 
637
- async def make_request(session, prompt: str):
638
  url = get_base_url("generate")
639
  data = create_generate_request_data(prompt, stream=False)
640
  try:
641
  async with session.post(url, json=data) as response:
642
  if response.status != 200:
643
- response.raise_for_status()
644
- return await response.json()
 
 
 
 
 
 
 
 
 
 
 
 
 
645
  except Exception as e:
646
- return {"error": str(e)}
 
 
 
647
 
648
  async def run_concurrent_requests():
649
  prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"]
650
 
651
  async with get_session() as session:
652
- tasks = [make_request(session, prompt) for prompt in prompts]
653
- results = await asyncio.gather(*tasks)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
654
  return results
655
 
656
  if OutputControl.is_verbose():
657
  print("\n=== Testing concurrent generate requests ===")
658
 
659
  # Run concurrent requests
660
- results = asyncio.run(run_concurrent_requests())
661
-
662
- # Print results
663
- for i, result in enumerate(results, 1):
664
- print(f"\nRequest {i} result:")
665
- print_json_response(result)
 
 
 
666
 
667
 
668
  def get_test_cases() -> Dict[str, Callable]:
 
17
  from dataclasses import dataclass, asdict
18
  from datetime import datetime
19
  from pathlib import Path
20
+ from enum import Enum, auto
21
+
22
+
23
+ class ErrorCode(Enum):
24
+ """Error codes for MCP errors"""
25
+
26
+ InvalidRequest = auto()
27
+ InternalError = auto()
28
+
29
+
30
+ class McpError(Exception):
31
+ """Base exception class for MCP errors"""
32
+
33
+ def __init__(self, code: ErrorCode, message: str):
34
+ self.code = code
35
+ self.message = message
36
+ super().__init__(message)
37
+
38
 
39
  DEFAULT_CONFIG = {
40
  "server": {
 
652
  async with aiohttp.ClientSession() as session:
653
  yield session
654
 
655
+ async def make_request(session, prompt: str, request_id: int):
656
  url = get_base_url("generate")
657
  data = create_generate_request_data(prompt, stream=False)
658
  try:
659
  async with session.post(url, json=data) as response:
660
  if response.status != 200:
661
+ error_msg = (
662
+ f"Request {request_id} failed with status {response.status}"
663
+ )
664
+ if OutputControl.is_verbose():
665
+ print(f"\n{error_msg}")
666
+ raise McpError(ErrorCode.InternalError, error_msg)
667
+ result = await response.json()
668
+ if "error" in result:
669
+ error_msg = (
670
+ f"Request {request_id} returned error: {result['error']}"
671
+ )
672
+ if OutputControl.is_verbose():
673
+ print(f"\n{error_msg}")
674
+ raise McpError(ErrorCode.InternalError, error_msg)
675
+ return result
676
  except Exception as e:
677
+ error_msg = f"Request {request_id} failed: {str(e)}"
678
+ if OutputControl.is_verbose():
679
+ print(f"\n{error_msg}")
680
+ raise McpError(ErrorCode.InternalError, error_msg)
681
 
682
  async def run_concurrent_requests():
683
  prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"]
684
 
685
  async with get_session() as session:
686
+ tasks = [
687
+ make_request(session, prompt, i + 1) for i, prompt in enumerate(prompts)
688
+ ]
689
+ results = await asyncio.gather(*tasks, return_exceptions=True)
690
+
691
+ # 收集成功和失败的结果
692
+ success_results = []
693
+ error_messages = []
694
+
695
+ for i, result in enumerate(results):
696
+ if isinstance(result, Exception):
697
+ error_messages.append(f"Request {i+1} failed: {str(result)}")
698
+ else:
699
+ success_results.append((i + 1, result))
700
+
701
+ # 如果有任何错误,在打印完所有结果后抛出异常
702
+ if error_messages:
703
+ # 先打印成功的结果
704
+ for req_id, result in success_results:
705
+ if OutputControl.is_verbose():
706
+ print(f"\nRequest {req_id} succeeded:")
707
+ print_json_response(result)
708
+
709
+ # 打印所有错误信息
710
+ error_summary = "\n".join(error_messages)
711
+ raise McpError(
712
+ ErrorCode.InternalError,
713
+ f"Some concurrent requests failed:\n{error_summary}",
714
+ )
715
+
716
  return results
717
 
718
  if OutputControl.is_verbose():
719
  print("\n=== Testing concurrent generate requests ===")
720
 
721
  # Run concurrent requests
722
+ try:
723
+ results = asyncio.run(run_concurrent_requests())
724
+ # 如果没有异常,打印所有成功的结果
725
+ for i, result in enumerate(results, 1):
726
+ print(f"\nRequest {i} result:")
727
+ print_json_response(result)
728
+ except McpError:
729
+ # 错误信息已经在之前打印过了,这里直接抛出
730
+ raise
731
 
732
 
733
  def get_test_cases() -> Dict[str, Callable]: