tackhwa committed on
Commit
af81714
·
1 Parent(s): 29c715d

fix hf bug

Browse files
examples/lightrag_siliconcloud_demo.py CHANGED
@@ -19,7 +19,7 @@ async def llm_model_func(
19
  prompt,
20
  system_prompt=system_prompt,
21
  history_messages=history_messages,
22
- api_key=os.getenv("UPSTAGE_API_KEY"),
23
  base_url="https://api.siliconflow.cn/v1/",
24
  **kwargs,
25
  )
@@ -29,7 +29,7 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
29
  return await siliconcloud_embedding(
30
  texts,
31
  model="netease-youdao/bce-embedding-base_v1",
32
- api_key=os.getenv("UPSTAGE_API_KEY"),
33
  max_token_size=512
34
  )
35
 
 
19
  prompt,
20
  system_prompt=system_prompt,
21
  history_messages=history_messages,
22
+ api_key=os.getenv("SILICONFLOW_API_KEY"),
23
  base_url="https://api.siliconflow.cn/v1/",
24
  **kwargs,
25
  )
 
29
  return await siliconcloud_embedding(
30
  texts,
31
  model="netease-youdao/bce-embedding-base_v1",
32
+ api_key=os.getenv("SILICONFLOW_API_KEY"),
33
  max_token_size=512
34
  )
35
 
lightrag/llm.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import copy
 
3
  import json
4
  import aioboto3
5
  import aiohttp
@@ -202,15 +203,22 @@ async def bedrock_complete_if_cache(
202
  return response["output"]["message"]["content"][0]["text"]
203
 
204
 
 
 
 
 
 
 
 
 
205
  async def hf_model_if_cache(
206
  model, prompt, system_prompt=None, history_messages=[], **kwargs
207
  ) -> str:
208
  model_name = model
209
- hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto")
210
  if hf_tokenizer.pad_token is None:
211
  # print("use eos token")
212
  hf_tokenizer.pad_token = hf_tokenizer.eos_token
213
- hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
214
  hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
215
  messages = []
216
  if system_prompt:
 
1
  import os
2
  import copy
3
+ from functools import lru_cache
4
  import json
5
  import aioboto3
6
  import aiohttp
 
203
  return response["output"]["message"]["content"][0]["text"]
204
 
205
 
206
+ @lru_cache(maxsize=1)
207
+ def initialize_hf_model(model_name):
208
+ hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
209
+ hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
210
+
211
+ return hf_model, hf_tokenizer
212
+
213
+
214
  async def hf_model_if_cache(
215
  model, prompt, system_prompt=None, history_messages=[], **kwargs
216
  ) -> str:
217
  model_name = model
218
+ hf_model, hf_tokenizer = initialize_hf_model(model_name)
219
  if hf_tokenizer.pad_token is None:
220
  # print("use eos token")
221
  hf_tokenizer.pad_token = hf_tokenizer.eos_token
 
222
  hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
223
  messages = []
224
  if system_prompt: