hunkim committed on
Commit
046051b
·
1 Parent(s): 68483c2

Added OpenAI compatible options and examples

Browse files
examples/lightrag_openai_compatible_demo.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import asyncio
3
+ from lightrag import LightRAG, QueryParam
4
+ from lightrag.llm import openai_complete_if_cache, openai_embedding
5
+ from lightrag.utils import EmbeddingFunc
6
+ import numpy as np
7
+
8
# Directory where LightRAG stores its caches, key-value stores, and graph data.
WORKING_DIR = "./dickens"

# os.makedirs(..., exist_ok=True) is idempotent and avoids the
# check-then-create race of `os.path.exists` followed by `os.mkdir`.
os.makedirs(WORKING_DIR, exist_ok=True)
12
+
13
async def llm_model_func(
    prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Complete `prompt` via Upstage's OpenAI-compatible Solar API.

    Args:
        prompt: User prompt to complete.
        system_prompt: Optional system message prepended to the conversation.
        history_messages: Prior chat messages; defaults to no history.
        **kwargs: Extra options forwarded to ``openai_complete_if_cache``.

    Returns:
        The model's completion text.
    """
    # Avoid the shared-mutable-default pitfall: build a fresh list per call
    # instead of using `history_messages=[]` in the signature.
    if history_messages is None:
        history_messages = []
    return await openai_complete_if_cache(
        "solar-mini",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        # Credentials and endpoint for the OpenAI-compatible Upstage service.
        api_key=os.getenv("UPSTAGE_API_KEY"),
        base_url="https://api.upstage.ai/v1/solar",
        **kwargs,
    )
25
+
26
async def embedding_func(texts: list[str]) -> np.ndarray:
    """Embed `texts` with Upstage's Solar embedding model.

    Returns one embedding vector per input text as a numpy array.
    """
    solar_endpoint = "https://api.upstage.ai/v1/solar"
    return await openai_embedding(
        texts,
        model="solar-embedding-1-large-query",
        api_key=os.getenv("UPSTAGE_API_KEY"),
        base_url=solar_endpoint,
    )
33
+
34
# Smoke-test both API wrappers once before building the index.
async def test_funcs():
    """Exercise the LLM and embedding wrappers with a trivial prompt."""
    completion = await llm_model_func("How are you?")
    print("llm_model_func: ", completion)

    vectors = await embedding_func(["How are you?"])
    print("embedding_func: ", vectors)

asyncio.run(test_funcs())
43
+
44
+
45
# Embedding wrapper: the Solar large embedding model produces 4096-dim
# vectors and accepts inputs up to 8192 tokens.
solar_embedder = EmbeddingFunc(
    embedding_dim=4096,
    max_token_size=8192,
    func=embedding_func,
)

# Wire LightRAG to the Solar chat model and embedder defined above.
rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=llm_model_func,
    embedding_func=solar_embedder,
)
54
+
55
+
56
# Index the source document. insert() extracts entities/relations and
# computes embeddings, so this is the slow step of the script.
# Pin the encoding explicitly: the default `open()` encoding is
# platform/locale dependent and may not be UTF-8.
with open("./book.txt", encoding="utf-8") as f:
    rag.insert(f.read())
58
+
59
question = "What are the top themes in this story?"

# Run the same question through each retrieval strategy, in the same
# order as before: naive, local, global, then hybrid.
for mode in ("naive", "local", "global", "hybrid"):
    print(rag.query(question, param=QueryParam(mode=mode)))
lightrag/llm.py CHANGED
@@ -19,9 +19,12 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false"
19
  retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
20
  )
21
  async def openai_complete_if_cache(
22
- model, prompt, system_prompt=None, history_messages=[], **kwargs
23
  ) -> str:
24
- openai_async_client = AsyncOpenAI()
 
 
 
25
  hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
26
  messages = []
27
  if system_prompt:
@@ -133,10 +136,13 @@ async def hf_model_complete(
133
  wait=wait_exponential(multiplier=1, min=4, max=10),
134
  retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
135
  )
136
- async def openai_embedding(texts: list[str]) -> np.ndarray:
137
- openai_async_client = AsyncOpenAI()
 
 
 
138
  response = await openai_async_client.embeddings.create(
139
- model="text-embedding-3-small", input=texts, encoding_format="float"
140
  )
141
  return np.array([dp.embedding for dp in response.data])
142
 
 
19
  retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
20
  )
21
  async def openai_complete_if_cache(
22
+ model, prompt, system_prompt=None, history_messages=[], base_url=None, api_key=None, **kwargs
23
  ) -> str:
24
+ if api_key:
25
+ os.environ["OPENAI_API_KEY"] = api_key
26
+
27
+ openai_async_client = AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
28
  hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
29
  messages = []
30
  if system_prompt:
 
136
  wait=wait_exponential(multiplier=1, min=4, max=10),
137
  retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
138
  )
139
+ async def openai_embedding(texts: list[str], model: str = "text-embedding-3-small", base_url: str = None, api_key: str = None) -> np.ndarray:
140
+ if api_key:
141
+ os.environ["OPENAI_API_KEY"] = api_key
142
+
143
+ openai_async_client = AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
144
  response = await openai_async_client.embeddings.create(
145
+ model=model, input=texts, encoding_format="float"
146
  )
147
  return np.array([dp.embedding for dp in response.data])
148