zrguo committed
Commit f10a3eb · unverified · 2 Parent(s): cf0bd5c e58ea27

Merge pull request #723 from danielaskdd/improve-ollama-api-streaming

lightrag/api/ollama_api.py CHANGED
@@ -205,14 +205,14 @@ class OllamaAPI:
         async def stream_generator():
             try:
                 first_chunk_time = None
-                last_chunk_time = None
+                last_chunk_time = time.time_ns()
                 total_response = ""
 
                 # Ensure response is an async generator
                 if isinstance(response, str):
                     # If it's a string, send in two parts
-                    first_chunk_time = time.time_ns()
-                    last_chunk_time = first_chunk_time
+                    first_chunk_time = start_time
+                    last_chunk_time = time.time_ns()
                     total_response = response
 
                     data = {
@@ -241,22 +241,50 @@
                     }
                     yield f"{json.dumps(data, ensure_ascii=False)}\n"
                 else:
-                    async for chunk in response:
-                        if chunk:
-                            if first_chunk_time is None:
-                                first_chunk_time = time.time_ns()
-
-                            last_chunk_time = time.time_ns()
-
-                            total_response += chunk
-                            data = {
-                                "model": self.ollama_server_infos.LIGHTRAG_MODEL,
-                                "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
-                                "response": chunk,
-                                "done": False,
-                            }
-                            yield f"{json.dumps(data, ensure_ascii=False)}\n"
-
+                    try:
+                        async for chunk in response:
+                            if chunk:
+                                if first_chunk_time is None:
+                                    first_chunk_time = time.time_ns()
+
+                                last_chunk_time = time.time_ns()
+
+                                total_response += chunk
+                                data = {
+                                    "model": self.ollama_server_infos.LIGHTRAG_MODEL,
+                                    "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
+                                    "response": chunk,
+                                    "done": False,
+                                }
+                                yield f"{json.dumps(data, ensure_ascii=False)}\n"
+                    except (asyncio.CancelledError, Exception) as e:
+                        error_msg = str(e)
+                        if isinstance(e, asyncio.CancelledError):
+                            error_msg = "Stream was cancelled by server"
+                        else:
+                            error_msg = f"Provider error: {error_msg}"
+
+                        logging.error(f"Stream error: {error_msg}")
+
+                        # Send error message to client
+                        error_data = {
+                            "model": self.ollama_server_infos.LIGHTRAG_MODEL,
+                            "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
+                            "response": f"\n\nError: {error_msg}",
+                            "done": False,
+                        }
+                        yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
+
+                        # Send final message to close the stream
+                        final_data = {
+                            "model": self.ollama_server_infos.LIGHTRAG_MODEL,
+                            "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
+                            "done": True,
+                        }
+                        yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
+                        return
+                if first_chunk_time is None:
+                    first_chunk_time = start_time
                 completion_tokens = estimate_tokens(total_response)
                 total_time = last_chunk_time - start_time
                 prompt_eval_time = first_chunk_time - start_time
@@ -381,16 +409,16 @@ class OllamaAPI:
             )
 
         async def stream_generator():
-            first_chunk_time = None
-            last_chunk_time = None
-            total_response = ""
-
             try:
+                first_chunk_time = None
+                last_chunk_time = time.time_ns()
+                total_response = ""
+
                 # Ensure response is an async generator
                 if isinstance(response, str):
                     # If it's a string, send in two parts
-                    first_chunk_time = time.time_ns()
-                    last_chunk_time = first_chunk_time
+                    first_chunk_time = start_time
+                    last_chunk_time = time.time_ns()
                     total_response = response
 
                     data = {
@@ -474,45 +502,29 @@
                         yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
                         return
 
-                if last_chunk_time is not None:
-                    completion_tokens = estimate_tokens(total_response)
-                    total_time = last_chunk_time - start_time
-                    prompt_eval_time = first_chunk_time - start_time
-                    eval_time = last_chunk_time - first_chunk_time
-
-                    data = {
-                        "model": self.ollama_server_infos.LIGHTRAG_MODEL,
-                        "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
-                        "done": True,
-                        "total_duration": total_time,
-                        "load_duration": 0,
-                        "prompt_eval_count": prompt_tokens,
-                        "prompt_eval_duration": prompt_eval_time,
-                        "eval_count": completion_tokens,
-                        "eval_duration": eval_time,
-                    }
-                    yield f"{json.dumps(data, ensure_ascii=False)}\n"
+                if first_chunk_time is None:
+                    first_chunk_time = start_time
+                completion_tokens = estimate_tokens(total_response)
+                total_time = last_chunk_time - start_time
+                prompt_eval_time = first_chunk_time - start_time
+                eval_time = last_chunk_time - first_chunk_time
+
+                data = {
+                    "model": self.ollama_server_infos.LIGHTRAG_MODEL,
+                    "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
+                    "done": True,
+                    "total_duration": total_time,
+                    "load_duration": 0,
+                    "prompt_eval_count": prompt_tokens,
+                    "prompt_eval_duration": prompt_eval_time,
+                    "eval_count": completion_tokens,
+                    "eval_duration": eval_time,
+                }
+                yield f"{json.dumps(data, ensure_ascii=False)}\n"
 
             except Exception as e:
-                error_msg = f"Error in stream_generator: {str(e)}"
-                logging.error(error_msg)
-
-                # Send error message to client
-                error_data = {
-                    "model": self.ollama_server_infos.LIGHTRAG_MODEL,
-                    "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
-                    "error": {"code": "STREAM_ERROR", "message": error_msg},
-                }
-                yield f"{json.dumps(error_data, ensure_ascii=False)}\n"
-
-                # Ensure sending end marker
-                final_data = {
-                    "model": self.ollama_server_infos.LIGHTRAG_MODEL,
-                    "created_at": self.ollama_server_infos.LIGHTRAG_CREATED_AT,
-                    "done": True,
-                }
-                yield f"{json.dumps(final_data, ensure_ascii=False)}\n"
-                return
+                trace_exception(e)
+                raise
 
             return StreamingResponse(
                 stream_generator(),
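
For readers skimming the hunks above, here is a minimal, self-contained sketch of the streaming pattern this change adopts (illustrative names only, not LightRAG's actual handler): last_chunk_time is initialized before iteration so the summary record never references an unset variable, provider errors are caught inside the generator, and the NDJSON stream always ends with a done:true line.

import asyncio
import json
import time


async def ndjson_stream(chunks, model="demo:latest"):
    # Timestamps are set eagerly so the final statistics are always computable.
    start_time = time.time_ns()
    first_chunk_time = None
    last_chunk_time = time.time_ns()
    total_response = ""
    try:
        async for chunk in chunks:
            if first_chunk_time is None:
                first_chunk_time = time.time_ns()
            last_chunk_time = time.time_ns()
            total_response += chunk
            yield json.dumps({"model": model, "response": chunk, "done": False}) + "\n"
    except Exception as e:
        # Report the provider error as a chunk, then close the stream cleanly.
        yield json.dumps({"model": model, "response": f"\n\nError: {e}", "done": False}) + "\n"
        yield json.dumps({"model": model, "done": True}) + "\n"
        return
    if first_chunk_time is None:
        first_chunk_time = start_time
    yield json.dumps({
        "model": model,
        "done": True,
        "total_duration": last_chunk_time - start_time,
        "prompt_eval_duration": first_chunk_time - start_time,
        "eval_duration": last_chunk_time - first_chunk_time,
        "eval_count": len(total_response.split()),  # rough token estimate
    }) + "\n"


async def _demo():
    async def fake_chunks():
        for part in ["Hello", ", ", "world"]:
            await asyncio.sleep(0)
            yield part

    async for line in ndjson_stream(fake_chunks()):
        print(line, end="")


if __name__ == "__main__":
    asyncio.run(_demo())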
lightrag/llm/ollama.py CHANGED
@@ -66,6 +66,7 @@ from lightrag.exceptions import (
     RateLimitError,
     APITimeoutError,
 )
+from lightrag.api import __api_version__
 import numpy as np
 from typing import Union
 
@@ -91,11 +92,12 @@ async def ollama_model_if_cache(
     timeout = kwargs.pop("timeout", None)
     kwargs.pop("hashing_kv", None)
     api_key = kwargs.pop("api_key", None)
-    headers = (
-        {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
-        if api_key
-        else {"Content-Type": "application/json"}
-    )
+    headers = {
+        "Content-Type": "application/json",
+        "User-Agent": f"LightRAG/{__api_version__}",
+    }
+    if api_key:
+        headers["Authorization"] = f"Bearer {api_key}"
     ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
     messages = []
     if system_prompt:
@@ -147,11 +149,12 @@ async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarra
 
 async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
     api_key = kwargs.pop("api_key", None)
-    headers = (
-        {"Content-Type": "application/json", "Authorization": api_key}
-        if api_key
-        else {"Content-Type": "application/json"}
-    )
+    headers = {
+        "Content-Type": "application/json",
+        "User-Agent": f"LightRAG/{__api_version__}",
+    }
+    if api_key:
+        headers["Authorization"] = api_key
    kwargs["headers"] = headers
    ollama_client = ollama.Client(**kwargs)
    data = ollama_client.embed(model=embed_model, input=texts)
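
The header handling above follows the same pattern in both functions; a small hedged sketch (placeholder values, not the repository's real version string or key):

api_version = "0.1.0"  # stands in for lightrag.api.__api_version__
api_key = None  # or e.g. "sk-..." when authentication is configured

headers = {
    "Content-Type": "application/json",
    "User-Agent": f"LightRAG/{api_version}",
}
if api_key:
    headers["Authorization"] = f"Bearer {api_key}"

print(headers)  # {'Content-Type': 'application/json', 'User-Agent': 'LightRAG/0.1.0'}

The User-Agent is now always sent, while Authorization is attached only when a key is provided; note that the chat client uses a Bearer prefix, whereas ollama_embed passes the key value as-is.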
lightrag/llm/openai.py CHANGED
@@ -73,16 +73,23 @@ from lightrag.utils import (
     logger,
 )
 from lightrag.types import GPTKeywordExtractionFormat
+from lightrag.api import __api_version__
 
 import numpy as np
 from typing import Union
 
 
+class InvalidResponseError(Exception):
+    """Custom exception class for triggering retry mechanism"""
+
+    pass
+
+
 @retry(
     stop=stop_after_attempt(3),
     wait=wait_exponential(multiplier=1, min=4, max=10),
     retry=retry_if_exception_type(
-        (RateLimitError, APIConnectionError, APITimeoutError)
+        (RateLimitError, APIConnectionError, APITimeoutError, InvalidResponseError)
     ),
 )
 async def openai_complete_if_cache(
@@ -99,8 +106,14 @@ async def openai_complete_if_cache(
     if api_key:
         os.environ["OPENAI_API_KEY"] = api_key
 
+    default_headers = {
+        "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
+        "Content-Type": "application/json",
+    }
     openai_async_client = (
-        AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
+        AsyncOpenAI(default_headers=default_headers)
+        if base_url is None
+        else AsyncOpenAI(base_url=base_url, default_headers=default_headers)
     )
     kwargs.pop("hashing_kv", None)
     kwargs.pop("keyword_extraction", None)
@@ -112,17 +125,35 @@ async def openai_complete_if_cache(
 
     # Add log output
     logger.debug("===== Query Input to LLM =====")
+    logger.debug(f"Model: {model} Base URL: {base_url}")
+    logger.debug(f"Additional kwargs: {kwargs}")
     logger.debug(f"Query: {prompt}")
     logger.debug(f"System prompt: {system_prompt}")
-    logger.debug("Full context:")
-    if "response_format" in kwargs:
-        response = await openai_async_client.beta.chat.completions.parse(
-            model=model, messages=messages, **kwargs
-        )
-    else:
-        response = await openai_async_client.chat.completions.create(
-            model=model, messages=messages, **kwargs
-        )
+    # logger.debug(f"Messages: {messages}")
+
+    try:
+        if "response_format" in kwargs:
+            response = await openai_async_client.beta.chat.completions.parse(
+                model=model, messages=messages, **kwargs
+            )
+        else:
+            response = await openai_async_client.chat.completions.create(
+                model=model, messages=messages, **kwargs
+            )
+    except APIConnectionError as e:
+        logger.error(f"OpenAI API Connection Error: {str(e)}")
+        raise
+    except RateLimitError as e:
+        logger.error(f"OpenAI API Rate Limit Error: {str(e)}")
+        raise
+    except APITimeoutError as e:
+        logger.error(f"OpenAI API Timeout Error: {str(e)}")
+        raise
+    except Exception as e:
+        logger.error(f"OpenAI API Call Failed: {str(e)}")
+        logger.error(f"Model: {model}")
+        logger.error(f"Request parameters: {kwargs}")
+        raise
 
     if hasattr(response, "__aiter__"):
 
@@ -140,8 +171,23 @@ async def openai_complete_if_cache(
                 raise
 
         return inner()
+
     else:
+        if (
+            not response
+            or not response.choices
+            or not hasattr(response.choices[0], "message")
+            or not hasattr(response.choices[0].message, "content")
+        ):
+            logger.error("Invalid response from OpenAI API")
+            raise InvalidResponseError("Invalid response from OpenAI API")
+
         content = response.choices[0].message.content
+
+        if not content or content.strip() == "":
+            logger.error("Received empty content from OpenAI API")
+            raise InvalidResponseError("Received empty content from OpenAI API")
+
         if r"\u" in content:
             content = safe_unicode_decode(content.encode("utf-8"))
         return content
@@ -251,8 +297,14 @@ async def openai_embed(
     if api_key:
         os.environ["OPENAI_API_KEY"] = api_key
 
+    default_headers = {
+        "User-Agent": f"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_8) LightRAG/{__api_version__}",
+        "Content-Type": "application/json",
+    }
     openai_async_client = (
-        AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
+        AsyncOpenAI(default_headers=default_headers)
+        if base_url is None
+        else AsyncOpenAI(base_url=base_url, default_headers=default_headers)
     )
     response = await openai_async_client.embeddings.create(
         model=model, input=texts, encoding_format="float"
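
The new InvalidResponseError exists so that empty or malformed completions trigger the same tenacity retry loop as connection and rate-limit errors. A compact sketch of that behaviour, assuming the tenacity package that stop_after_attempt, wait_exponential, and retry_if_exception_type come from (the fake client and the shortened wait times are illustrative only):

from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential


class InvalidResponseError(Exception):
    """Raised for empty or malformed completions so the retry decorator fires."""


calls = {"n": 0}


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=0.01, max=0.1),  # shorter waits than the repo's decorator, for the demo
    retry=retry_if_exception_type((InvalidResponseError,)),
)
def fake_completion() -> str:
    calls["n"] += 1
    content = "" if calls["n"] < 2 else "ok"  # first attempt returns empty content
    if not content or content.strip() == "":
        raise InvalidResponseError("Received empty content")
    return content


print(fake_completion(), "after", calls["n"], "attempts")  # ok after 2 attempts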
test_lightrag_ollama_chat.py CHANGED
@@ -17,14 +17,32 @@ from typing import Dict, Any, Optional, List, Callable
 from dataclasses import dataclass, asdict
 from datetime import datetime
 from pathlib import Path
+from enum import Enum, auto
+
+
+class ErrorCode(Enum):
+    """Error codes for MCP errors"""
+
+    InvalidRequest = auto()
+    InternalError = auto()
+
+
+class McpError(Exception):
+    """Base exception class for MCP errors"""
+
+    def __init__(self, code: ErrorCode, message: str):
+        self.code = code
+        self.message = message
+        super().__init__(message)
+
 
 DEFAULT_CONFIG = {
     "server": {
         "host": "localhost",
         "port": 9621,
         "model": "lightrag:latest",
-        "timeout": 120,
-        "max_retries": 3,
+        "timeout": 300,
+        "max_retries": 1,
         "retry_delay": 1,
     },
     "test_cases": {
@@ -527,14 +545,7 @@ def test_non_stream_generate() -> None:
     response_json = response.json()
 
     # Print response content
-    print_json_response(
-        {
-            "model": response_json["model"],
-            "response": response_json["response"],
-            "done": response_json["done"],
-        },
-        "Response content",
-    )
+    print(json.dumps(response_json, ensure_ascii=False, indent=2))
 
 
 def test_stream_generate() -> None:
@@ -641,35 +652,78 @@ def test_generate_concurrent() -> None:
         async with aiohttp.ClientSession() as session:
             yield session
 
-    async def make_request(session, prompt: str):
+    async def make_request(session, prompt: str, request_id: int):
         url = get_base_url("generate")
         data = create_generate_request_data(prompt, stream=False)
         try:
             async with session.post(url, json=data) as response:
                 if response.status != 200:
-                    response.raise_for_status()
-                return await response.json()
+                    error_msg = (
+                        f"Request {request_id} failed with status {response.status}"
+                    )
+                    if OutputControl.is_verbose():
+                        print(f"\n{error_msg}")
+                    raise McpError(ErrorCode.InternalError, error_msg)
+                result = await response.json()
+                if "error" in result:
+                    error_msg = (
+                        f"Request {request_id} returned error: {result['error']}"
+                    )
+                    if OutputControl.is_verbose():
+                        print(f"\n{error_msg}")
+                    raise McpError(ErrorCode.InternalError, error_msg)
+                return result
         except Exception as e:
-            return {"error": str(e)}
+            error_msg = f"Request {request_id} failed: {str(e)}"
+            if OutputControl.is_verbose():
+                print(f"\n{error_msg}")
+            raise McpError(ErrorCode.InternalError, error_msg)
 
     async def run_concurrent_requests():
         prompts = ["第一个问题", "第二个问题", "第三个问题", "第四个问题", "第五个问题"]
 
         async with get_session() as session:
-            tasks = [make_request(session, prompt) for prompt in prompts]
-            results = await asyncio.gather(*tasks)
+            tasks = [
+                make_request(session, prompt, i + 1) for i, prompt in enumerate(prompts)
+            ]
+            results = await asyncio.gather(*tasks, return_exceptions=True)
+
+            success_results = []
+            error_messages = []
+
+            for i, result in enumerate(results):
+                if isinstance(result, Exception):
+                    error_messages.append(f"Request {i+1} failed: {str(result)}")
+                else:
+                    success_results.append((i + 1, result))
+
+            if error_messages:
+                for req_id, result in success_results:
+                    if OutputControl.is_verbose():
+                        print(f"\nRequest {req_id} succeeded:")
+                        print_json_response(result)
+
+                error_summary = "\n".join(error_messages)
+                raise McpError(
+                    ErrorCode.InternalError,
+                    f"Some concurrent requests failed:\n{error_summary}",
+                )
+
         return results
 
     if OutputControl.is_verbose():
        print("\n=== Testing concurrent generate requests ===")
 
     # Run concurrent requests
-    results = asyncio.run(run_concurrent_requests())
-
-    # Print results
-    for i, result in enumerate(results, 1):
-        print(f"\nRequest {i} result:")
-        print_json_response(result)
+    try:
+        results = asyncio.run(run_concurrent_requests())
+        # all success, print out results
+        for i, result in enumerate(results, 1):
+            print(f"\nRequest {i} result:")
+            print_json_response(result)
+    except McpError:
+        # error message already printed
+        raise
 
 
 def get_test_cases() -> Dict[str, Callable]:
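
The concurrent test now relies on asyncio.gather(..., return_exceptions=True) so a single failed request no longer aborts the batch silently; failures are collected and re-raised together. A standalone sketch of that pattern (hypothetical request function, not the test's real helpers):

import asyncio


class RequestError(Exception):
    pass


async def fake_request(request_id: int) -> dict:
    await asyncio.sleep(0)
    if request_id == 3:  # simulate one failing request
        raise RequestError(f"Request {request_id} failed with status 500")
    return {"id": request_id, "response": "ok", "done": True}


async def run_concurrent() -> list:
    tasks = [fake_request(i + 1) for i in range(5)]
    # Exceptions are returned in place of results instead of propagating immediately.
    results = await asyncio.gather(*tasks, return_exceptions=True)
    errors = [
        f"Request {i + 1} failed: {r}"
        for i, r in enumerate(results)
        if isinstance(r, Exception)
    ]
    if errors:
        raise RequestError("Some concurrent requests failed:\n" + "\n".join(errors))
    return results


if __name__ == "__main__":
    try:
        print(asyncio.run(run_concurrent()))
    except RequestError as e:
        print(e)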