Fix HF bug: load the Hugging Face model and tokenizer once and cache them
Browse files
- examples/lightrag_siliconcloud_demo.py +2 -2
- lightrag/llm.py +10 -2
examples/lightrag_siliconcloud_demo.py
CHANGED
@@ -19,7 +19,7 @@ async def llm_model_func(
|
|
19 |
prompt,
|
20 |
system_prompt=system_prompt,
|
21 |
history_messages=history_messages,
|
22 |
-
api_key=os.getenv("
|
23 |
base_url="https://api.siliconflow.cn/v1/",
|
24 |
**kwargs,
|
25 |
)
|
@@ -29,7 +29,7 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
|
|
29 |
return await siliconcloud_embedding(
|
30 |
texts,
|
31 |
model="netease-youdao/bce-embedding-base_v1",
|
32 |
-
api_key=os.getenv("
|
33 |
max_token_size=512
|
34 |
)
|
35 |
|
|
|
19 |
prompt,
|
20 |
system_prompt=system_prompt,
|
21 |
history_messages=history_messages,
|
22 |
+
api_key=os.getenv("SILICONFLOW_API_KEY"),
|
23 |
base_url="https://api.siliconflow.cn/v1/",
|
24 |
**kwargs,
|
25 |
)
|
|
|
29 |
return await siliconcloud_embedding(
|
30 |
texts,
|
31 |
model="netease-youdao/bce-embedding-base_v1",
|
32 |
+
api_key=os.getenv("SILICONFLOW_API_KEY"),
|
33 |
max_token_size=512
|
34 |
)
|
35 |
|
lightrag/llm.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import os
|
2 |
import copy
|
|
|
3 |
import json
|
4 |
import aioboto3
|
5 |
import aiohttp
|
@@ -202,15 +203,22 @@ async def bedrock_complete_if_cache(
|
|
202 |
return response["output"]["message"]["content"][0]["text"]
|
203 |
|
204 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
async def hf_model_if_cache(
|
206 |
model, prompt, system_prompt=None, history_messages=[], **kwargs
|
207 |
) -> str:
|
208 |
model_name = model
|
209 |
-
hf_tokenizer =
|
210 |
if hf_tokenizer.pad_token is None:
|
211 |
# print("use eos token")
|
212 |
hf_tokenizer.pad_token = hf_tokenizer.eos_token
|
213 |
-
hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
|
214 |
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
215 |
messages = []
|
216 |
if system_prompt:
|
|
|
1 |
import os
|
2 |
import copy
|
3 |
+
from functools import lru_cache
|
4 |
import json
|
5 |
import aioboto3
|
6 |
import aiohttp
|
|
|
203 |
return response["output"]["message"]["content"][0]["text"]
|
204 |
|
205 |
|
206 |
+
@lru_cache(maxsize=1)
def initialize_hf_model(model_name):
    """Load the Hugging Face causal LM and tokenizer for ``model_name`` once.

    The result is memoized with ``lru_cache(maxsize=1)``: repeated calls with
    the same ``model_name`` return the same ``(model, tokenizer)`` pair
    instead of reloading it, and a call with a different name replaces the
    cached entry.

    Args:
        model_name: Identifier passed unchanged to ``from_pretrained``.

    Returns:
        A ``(hf_model, hf_tokenizer)`` tuple.
    """
    # NOTE(review): trust_remote_code=True lets the loaders run Python code
    # shipped inside the model repository; only use model names you trust.
    hf_tokenizer = AutoTokenizer.from_pretrained(
        model_name, device_map="auto", trust_remote_code=True
    )
    hf_model = AutoModelForCausalLM.from_pretrained(
        model_name, device_map="auto", trust_remote_code=True
    )

    # Configure the padding fallback here, once per load, so the cached
    # tokenizer is ready to use rather than being patched on every call.
    if hf_tokenizer.pad_token is None:
        hf_tokenizer.pad_token = hf_tokenizer.eos_token

    return hf_model, hf_tokenizer
|
212 |
+
|
213 |
+
|
214 |
async def hf_model_if_cache(
|
215 |
model, prompt, system_prompt=None, history_messages=[], **kwargs
|
216 |
) -> str:
|
217 |
model_name = model
|
218 |
+
hf_model, hf_tokenizer = initialize_hf_model(model_name)
|
219 |
if hf_tokenizer.pad_token is None:
|
220 |
# print("use eos token")
|
221 |
hf_tokenizer.pad_token = hf_tokenizer.eos_token
|
|
|
222 |
hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
|
223 |
messages = []
|
224 |
if system_prompt:
|