partoneplay committed
Commit f4c9977 · 1 Parent(s): 4c06d73

Add support for OpenAI Compatible Streaming output

examples/lightrag_openai_compatible_stream_demo.py ADDED
@@ -0,0 +1,55 @@
+import os
+import inspect
+from lightrag import LightRAG
+from lightrag.llm import openai_complete, openai_embedding
+from lightrag.utils import EmbeddingFunc
+from lightrag.lightrag import always_get_an_event_loop
+from lightrag import QueryParam
+
+# WorkingDir
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+WORKING_DIR = os.path.join(ROOT_DIR, "dickens")
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+print(f"WorkingDir: {WORKING_DIR}")
+
+api_key = "empty"
+rag = LightRAG(
+    working_dir=WORKING_DIR,
+    llm_model_func=openai_complete,
+    llm_model_name="qwen2.5-14b-instruct@4bit",
+    llm_model_max_async=4,
+    llm_model_max_token_size=32768,
+    llm_model_kwargs={"base_url": "http://127.0.0.1:1234/v1", "api_key": api_key},
+    embedding_func=EmbeddingFunc(
+        embedding_dim=1024,
+        max_token_size=8192,
+        func=lambda texts: openai_embedding(
+            texts=texts,
+            model="text-embedding-bge-m3",
+            base_url="http://127.0.0.1:1234/v1",
+            api_key=api_key,
+        ),
+    ),
+)
+
+with open("./book.txt", "r", encoding="utf-8") as f:
+    rag.insert(f.read())
+
+resp = rag.query(
+    "What are the top themes in this story?",
+    param=QueryParam(mode="hybrid", stream=True),
+)
+
+
+async def print_stream(stream):
+    async for chunk in stream:
+        if chunk:
+            print(chunk, end="", flush=True)
+
+
+loop = always_get_an_event_loop()
+if inspect.isasyncgen(resp):
+    loop.run_until_complete(print_stream(resp))
+else:
+    print(resp)
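
Note: the demo above prints chunks as they arrive. A minimal variation (a sketch, reusing the demo's `resp`, `inspect`, and `always_get_an_event_loop`; the `collect_stream` helper is illustrative and not part of LightRAG) collects the streamed chunks into a single string on the same event loop:

async def collect_stream(stream) -> str:
    # Illustrative helper: accumulate streamed chunks into one answer string.
    parts = []
    async for chunk in stream:
        if chunk:
            parts.append(chunk)
    return "".join(parts)


loop = always_get_an_event_loop()
answer = (
    loop.run_until_complete(collect_stream(resp))
    if inspect.isasyncgen(resp)
    else resp
)
print(answer)
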
lightrag/llm.py CHANGED
@@ -91,26 +91,40 @@ async def openai_complete_if_cache(
     response = await openai_async_client.chat.completions.create(
         model=model, messages=messages, **kwargs
     )
-    content = response.choices[0].message.content
-    if r"\u" in content:
-        content = content.encode("utf-8").decode("unicode_escape")
 
-    # Save to cache
-    await save_to_cache(
-        hashing_kv,
-        CacheData(
-            args_hash=args_hash,
-            content=content,
-            model=model,
-            prompt=prompt,
-            quantized=quantized,
-            min_val=min_val,
-            max_val=max_val,
-            mode=mode,
-        ),
-    )
+    if hasattr(response, "__aiter__"):
 
-    return content
+        async def inner():
+            async for chunk in response:
+                content = chunk.choices[0].delta.content
+                if content is None:
+                    continue
+                if r"\u" in content:
+                    content = content.encode("utf-8").decode("unicode_escape")
+                yield content
+
+        return inner()
+    else:
+        content = response.choices[0].message.content
+        if r"\u" in content:
+            content = content.encode("utf-8").decode("unicode_escape")
+
+        # Save to cache
+        await save_to_cache(
+            hashing_kv,
+            CacheData(
+                args_hash=args_hash,
+                content=content,
+                model=model,
+                prompt=prompt,
+                quantized=quantized,
+                min_val=min_val,
+                max_val=max_val,
+                mode=mode,
+            ),
+        )
+
+        return content
 
 
 @retry(
@@ -431,7 +445,7 @@ async def ollama_model_if_cache(
 
     response = await ollama_client.chat(model=model, messages=messages, **kwargs)
     if stream:
-        """ cannot cache stream response """
+        """cannot cache stream response"""
 
         async def inner():
             async for chunk in response:
@@ -613,6 +627,22 @@ class GPTKeywordExtractionFormat(BaseModel):
     low_level_keywords: List[str]
 
 
+async def openai_complete(
+    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
+) -> Union[str, AsyncIterator[str]]:
+    keyword_extraction = kwargs.pop("keyword_extraction", None)
+    if keyword_extraction:
+        kwargs["response_format"] = "json"
+    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
+    return await openai_complete_if_cache(
+        model_name,
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        **kwargs,
+    )
+
+
 async def gpt_4o_complete(
     prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 ) -> str:
@@ -1089,12 +1119,14 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
     mode_cache[cache_data.args_hash] = {
         "return": cache_data.content,
         "model": cache_data.model,
-        "embedding": cache_data.quantized.tobytes().hex()
-        if cache_data.quantized is not None
-        else None,
-        "embedding_shape": cache_data.quantized.shape
-        if cache_data.quantized is not None
-        else None,
+        "embedding": (
+            cache_data.quantized.tobytes().hex()
+            if cache_data.quantized is not None
+            else None
+        ),
+        "embedding_shape": (
+            cache_data.quantized.shape if cache_data.quantized is not None else None
+        ),
         "embedding_min": cache_data.min_val,
         "embedding_max": cache_data.max_val,
         "original_prompt": cache_data.prompt,