ParisNeo committed
Commit 0553d6a · 1 Parent(s): 40504da

Separated the LLMs from the main llm.py file and fixed some deprecation bugs

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +2 -0
  2. README.md +5 -5
  3. config.ini +0 -13
  4. examples/insert_custom_kg.py +1 -1
  5. examples/lightrag_api_ollama_demo.py +2 -2
  6. examples/lightrag_api_open_webui_demo.py +1 -1
  7. examples/lightrag_api_openai_compatible_demo.py +2 -2
  8. examples/lightrag_api_oracle_demo.py +2 -2
  9. examples/lightrag_bedrock_demo.py +2 -2
  10. examples/lightrag_hf_demo.py +2 -2
  11. examples/lightrag_jinaai_demo.py +3 -2
  12. examples/lightrag_lmdeploy_demo.py +3 -2
  13. examples/lightrag_nvidia_demo.py +3 -3
  14. examples/lightrag_ollama_age_demo.py +2 -2
  15. examples/lightrag_ollama_demo.py +2 -2
  16. examples/lightrag_ollama_gremlin_demo.py +2 -2
  17. examples/lightrag_ollama_neo4j_milvus_mongo_demo.py +1 -1
  18. examples/lightrag_openai_compatible_demo.py +2 -2
  19. examples/lightrag_openai_compatible_demo_embedding_cache.py +2 -2
  20. examples/lightrag_openai_compatible_stream_demo.py +2 -2
  21. examples/lightrag_openai_demo.py +1 -1
  22. examples/lightrag_openai_neo4j_milvus_redis_demo.py +1 -1
  23. examples/lightrag_oracle_demo.py +2 -2
  24. examples/lightrag_siliconcloud_demo.py +2 -1
  25. examples/lightrag_zhipu_demo.py +1 -1
  26. examples/lightrag_zhipu_postgres_demo.py +1 -1
  27. examples/test.py +1 -1
  28. examples/test_chromadb.py +2 -2
  29. examples/test_neo4j.py +1 -1
  30. examples/test_split_by_character.ipynb +3 -3
  31. examples/vram_management_demo.py +2 -2
  32. lightrag/api/lightrag_server.py +16 -6
  33. lightrag/api/requirements.txt +0 -1
  34. lightrag/exceptions.py +55 -0
  35. lightrag/kg/redis_impl.py +3 -2
  36. lightrag/lightrag.py +2 -6
  37. lightrag/llm.py +3 -1207
  38. lightrag/llm/__init__.py +0 -0
  39. lightrag/llm/azure_openai.py +188 -0
  40. lightrag/llm/bedrock.py +229 -0
  41. lightrag/llm/hf.py +187 -0
  42. lightrag/llm/jina.py +104 -0
  43. lightrag/llm/lmdeploy.py +190 -0
  44. lightrag/llm/lollms.py +222 -0
  45. lightrag/llm/nvidia_openai.py +112 -0
  46. lightrag/llm/ollama.py +155 -0
  47. lightrag/llm/openai.py +232 -0
  48. lightrag/llm/siliconcloud.py +121 -0
  49. lightrag/llm/zhipu.py +250 -0
  50. lightrag/storage.py +2 -0
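Most of the change is a mechanical import migration: provider helpers that previously lived in the single `lightrag.llm` module now live in per-provider submodules under `lightrag/llm/`, and several `*_embedding` helpers are renamed to `*_embed`. A minimal before/after sketch of the new import style (the symbols are the ones that appear in the diffs below):

```python
# Old style (removed by this commit): everything came from the monolithic module
# from lightrag.llm import gpt_4o_mini_complete, openai_embedding, ollama_model_complete

# New style: import from the per-provider submodules; note the *_embed renames
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed
from lightrag.llm.ollama import ollama_model_complete, ollama_embed
from lightrag.llm.hf import hf_model_complete, hf_embed
```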
.gitignore CHANGED
@@ -22,3 +22,5 @@ venv/
 examples/input/
 examples/output/
 .DS_Store
+#Remove config.ini from repo
+*.ini
README.md CHANGED
@@ -81,7 +81,7 @@ Use the below Python snippet (in a script) to initialize LightRAG and perform qu
 ```python
 import os
 from lightrag import LightRAG, QueryParam
-from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
+from lightrag.llm.openai import gpt_4o_mini_complete, gpt_4o_complete
 
 #########
 # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
@@ -177,7 +177,7 @@ async def llm_model_func(
     )
 
 async def embedding_func(texts: list[str]) -> np.ndarray:
-    return await openai_embedding(
+    return await openai_embed(
         texts,
         model="solar-embedding-1-large-query",
         api_key=os.getenv("UPSTAGE_API_KEY"),
@@ -233,7 +233,7 @@ If you want to use Ollama models, you need to pull model you plan to use and emb
 Then you only need to set LightRAG as follows:
 
 ```python
-from lightrag.llm import ollama_model_complete, ollama_embedding
+from lightrag.llm.ollama import ollama_model_complete, ollama_embed
 from lightrag.utils import EmbeddingFunc
 
 # Initialize LightRAG with Ollama model
@@ -245,7 +245,7 @@ rag = LightRAG(
     embedding_func=EmbeddingFunc(
         embedding_dim=768,
         max_token_size=8192,
-        func=lambda texts: ollama_embedding(
+        func=lambda texts: ollama_embed(
             texts,
             embed_model="nomic-embed-text"
         )
@@ -690,7 +690,7 @@ if __name__ == "__main__":
 | **entity\_summary\_to\_max\_tokens** | `int` | Maximum token size for each entity summary | `500` |
 | **node\_embedding\_algorithm** | `str` | Algorithm for node embedding (currently not used) | `node2vec` |
 | **node2vec\_params** | `dict` | Parameters for node embedding | `{"dimensions": 1536,"num_walks": 10,"walk_length": 40,"window_size": 2,"iterations": 3,"random_seed": 3,}` |
-| **embedding\_func** | `EmbeddingFunc` | Function to generate embedding vectors from text | `openai_embedding` |
+| **embedding\_func** | `EmbeddingFunc` | Function to generate embedding vectors from text | `openai_embed` |
 | **embedding\_batch\_num** | `int` | Maximum batch size for embedding processes (multiple texts sent per batch) | `32` |
 | **embedding\_func\_max\_async** | `int` | Maximum number of concurrent asynchronous embedding processes | `16` |
 | **llm\_model\_func** | `callable` | Function for LLM generation | `gpt_4o_mini_complete` |
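The renamed `openai_embed` keeps the old `openai_embedding` call shape, so the README's Upstage example only changes the function name. A minimal sketch, assuming `UPSTAGE_API_KEY` is set; the `embedding_dim` value is an assumption about the solar embedding model rather than something stated in this diff:

```python
import os

import numpy as np

from lightrag.llm.openai import openai_embed
from lightrag.utils import EmbeddingFunc


async def embedding_func(texts: list[str]) -> np.ndarray:
    # Same arguments as the old openai_embedding helper; only the name changed
    return await openai_embed(
        texts,
        model="solar-embedding-1-large-query",
        api_key=os.getenv("UPSTAGE_API_KEY"),
    )


# Wrapped the same way the README passes it to LightRAG(...)
upstage_embedding = EmbeddingFunc(
    embedding_dim=4096,  # assumption: output dimension of solar-embedding-1-large
    max_token_size=8192,
    func=embedding_func,
)
```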
config.ini DELETED
@@ -1,13 +0,0 @@
-[redis]
-uri = redis://localhost:6379
-
-[neo4j]
-uri = #
-username = neo4j
-password = 12345678
-
-[milvus]
-uri = #
-user = root
-password = Milvus
-db_name = lightrag
examples/insert_custom_kg.py CHANGED
@@ -1,6 +1,6 @@
 import os
 from lightrag import LightRAG
-from lightrag.llm import gpt_4o_mini_complete
+from lightrag.llm.openai import gpt_4o_mini_complete
 #########
 # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
 # import nest_asyncio
examples/lightrag_api_ollama_demo.py CHANGED
@@ -2,7 +2,7 @@ from fastapi import FastAPI, HTTPException, File, UploadFile
 from pydantic import BaseModel
 import os
 from lightrag import LightRAG, QueryParam
-from lightrag.llm import ollama_embedding, ollama_model_complete
+from lightrag.llm.ollama import ollama_embed, ollama_model_complete
 from lightrag.utils import EmbeddingFunc
 from typing import Optional
 import asyncio
@@ -38,7 +38,7 @@ rag = LightRAG(
     embedding_func=EmbeddingFunc(
         embedding_dim=768,
         max_token_size=8192,
-        func=lambda texts: ollama_embedding(
+        func=lambda texts: ollama_embed(
             texts, embed_model="nomic-embed-text", host="http://localhost:11434"
         ),
     ),
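All of the Ollama demos pick up the same two changes: the import moves to `lightrag.llm.ollama` and `ollama_embedding` becomes `ollama_embed`. A sketch of the resulting wiring, assuming a local Ollama server with the `nomic-embed-text` model already pulled; the working directory and LLM model name are illustrative:

```python
from lightrag import LightRAG
from lightrag.llm.ollama import ollama_embed, ollama_model_complete
from lightrag.utils import EmbeddingFunc

rag = LightRAG(
    working_dir="./dickens",              # illustrative
    llm_model_func=ollama_model_complete,
    llm_model_name="mistral:7b",          # illustrative model tag
    embedding_func=EmbeddingFunc(
        embedding_dim=768,                # nomic-embed-text returns 768-dimensional vectors
        max_token_size=8192,
        func=lambda texts: ollama_embed(
            texts, embed_model="nomic-embed-text", host="http://localhost:11434"
        ),
    ),
)
```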
examples/lightrag_api_open_webui_demo.py CHANGED
@@ -9,7 +9,7 @@ from typing import Optional
 import os
 import logging
 from lightrag import LightRAG, QueryParam
-from lightrag.llm import ollama_model_complete, ollama_embed
+from lightrag.llm.ollama import ollama_model_complete, ollama_embed
 from lightrag.utils import EmbeddingFunc
 
 import nest_asyncio
examples/lightrag_api_openai_compatible_demo.py CHANGED
@@ -2,7 +2,7 @@ from fastapi import FastAPI, HTTPException, File, UploadFile
 from pydantic import BaseModel
 import os
 from lightrag import LightRAG, QueryParam
-from lightrag.llm import openai_complete_if_cache, openai_embedding
+from lightrag.llm.openai import openai_complete_if_cache, openai_embed
 from lightrag.utils import EmbeddingFunc
 import numpy as np
 from typing import Optional
@@ -48,7 +48,7 @@ async def llm_model_func(
 
 
 async def embedding_func(texts: list[str]) -> np.ndarray:
-    return await openai_embedding(
+    return await openai_embed(
         texts,
         model=EMBEDDING_MODEL,
     )
examples/lightrag_api_oracle_demo.py CHANGED
@@ -13,7 +13,7 @@ from pathlib import Path
 import asyncio
 import nest_asyncio
 from lightrag import LightRAG, QueryParam
-from lightrag.llm import openai_complete_if_cache, openai_embedding
+from lightrag.llm.openai import openai_complete_if_cache, openai_embed
 from lightrag.utils import EmbeddingFunc
 import numpy as np
 
@@ -64,7 +64,7 @@ async def llm_model_func(
 
 
 async def embedding_func(texts: list[str]) -> np.ndarray:
-    return await openai_embedding(
+    return await openai_embed(
         texts,
         model=EMBEDDING_MODEL,
         api_key=APIKEY,
examples/lightrag_bedrock_demo.py CHANGED
@@ -6,7 +6,7 @@ import os
 import logging
 
 from lightrag import LightRAG, QueryParam
-from lightrag.llm import bedrock_complete, bedrock_embedding
+from lightrag.llm.bedrock import bedrock_complete, bedrock_embed
 from lightrag.utils import EmbeddingFunc
 
 logging.getLogger("aiobotocore").setLevel(logging.WARNING)
@@ -20,7 +20,7 @@ rag = LightRAG(
     llm_model_func=bedrock_complete,
     llm_model_name="Anthropic Claude 3 Haiku // Amazon Bedrock",
     embedding_func=EmbeddingFunc(
-        embedding_dim=1024, max_token_size=8192, func=bedrock_embedding
+        embedding_dim=1024, max_token_size=8192, func=bedrock_embed
     ),
 )
 
examples/lightrag_hf_demo.py CHANGED
@@ -1,7 +1,7 @@
 import os
 
 from lightrag import LightRAG, QueryParam
-from lightrag.llm import hf_model_complete, hf_embedding
+from lightrag.llm.hf import hf_model_complete, hf_embed
 from lightrag.utils import EmbeddingFunc
 from transformers import AutoModel, AutoTokenizer
 
@@ -17,7 +17,7 @@ rag = LightRAG(
     embedding_func=EmbeddingFunc(
         embedding_dim=384,
         max_token_size=5000,
-        func=lambda texts: hf_embedding(
+        func=lambda texts: hf_embed(
             texts,
             tokenizer=AutoTokenizer.from_pretrained(
                 "sentence-transformers/all-MiniLM-L6-v2"
examples/lightrag_jinaai_demo.py CHANGED
@@ -1,13 +1,14 @@
 import numpy as np
 from lightrag import LightRAG, QueryParam
 from lightrag.utils import EmbeddingFunc
-from lightrag.llm import jina_embedding, openai_complete_if_cache
+from lightrag.llm.jina import jina_embed
+from lightrag.llm.openai import openai_complete_if_cache
 import os
 import asyncio
 
 
 async def embedding_func(texts: list[str]) -> np.ndarray:
-    return await jina_embedding(texts, api_key="YourJinaAPIKey")
+    return await jina_embed(texts, api_key="YourJinaAPIKey")
 
 
 WORKING_DIR = "./dickens"
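The Jina demo now takes its embedding helper and its completion helper from two different submodules. A sketch of the pair; the model name passed to `openai_complete_if_cache` is illustrative rather than taken from the demo:

```python
import numpy as np

from lightrag.llm.jina import jina_embed
from lightrag.llm.openai import openai_complete_if_cache


async def embedding_func(texts: list[str]) -> np.ndarray:
    # jina_embed keeps the call shape of the former jina_embedding helper
    return await jina_embed(texts, api_key="YourJinaAPIKey")


async def llm_model_func(prompt, system_prompt=None, history_messages=[], **kwargs) -> str:
    return await openai_complete_if_cache(
        "gpt-4o-mini",  # illustrative model name
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        **kwargs,
    )
```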
examples/lightrag_lmdeploy_demo.py CHANGED
@@ -1,7 +1,8 @@
 import os
 
 from lightrag import LightRAG, QueryParam
-from lightrag.llm import lmdeploy_model_if_cache, hf_embedding
+from lightrag.llm.lmdeploy import lmdeploy_model_if_cache
+from lightrag.llm.hf import hf_embed
 from lightrag.utils import EmbeddingFunc
 from transformers import AutoModel, AutoTokenizer
 
@@ -42,7 +43,7 @@ rag = LightRAG(
     embedding_func=EmbeddingFunc(
         embedding_dim=384,
         max_token_size=5000,
-        func=lambda texts: hf_embedding(
+        func=lambda texts: hf_embed(
             texts,
             tokenizer=AutoTokenizer.from_pretrained(
                 "sentence-transformers/all-MiniLM-L6-v2"
examples/lightrag_nvidia_demo.py CHANGED
@@ -3,7 +3,7 @@ import asyncio
 from lightrag import LightRAG, QueryParam
 from lightrag.llm import (
     openai_complete_if_cache,
-    nvidia_openai_embedding,
+    nvidia_openai_embed,
 )
 from lightrag.utils import EmbeddingFunc
 import numpy as np
@@ -47,7 +47,7 @@ nvidia_embed_model = "nvidia/nv-embedqa-e5-v5"
 
 
 async def indexing_embedding_func(texts: list[str]) -> np.ndarray:
-    return await nvidia_openai_embedding(
+    return await nvidia_openai_embed(
         texts,
         model=nvidia_embed_model,  # maximum 512 token
         # model="nvidia/llama-3.2-nv-embedqa-1b-v1",
@@ -60,7 +60,7 @@ async def indexing_embedding_func(texts: list[str]) -> np.ndarray:
 
 
 async def query_embedding_func(texts: list[str]) -> np.ndarray:
-    return await nvidia_openai_embedding(
+    return await nvidia_openai_embed(
         texts,
         model=nvidia_embed_model,  # maximum 512 token
         # model="nvidia/llama-3.2-nv-embedqa-1b-v1",
examples/lightrag_ollama_age_demo.py CHANGED
@@ -4,7 +4,7 @@ import logging
 import os
 
 from lightrag import LightRAG, QueryParam
-from lightrag.llm import ollama_embedding, ollama_model_complete
+from lightrag.llm.ollama import ollama_embed, ollama_model_complete
 from lightrag.utils import EmbeddingFunc
 
 WORKING_DIR = "./dickens_age"
@@ -32,7 +32,7 @@ rag = LightRAG(
     embedding_func=EmbeddingFunc(
         embedding_dim=768,
         max_token_size=8192,
-        func=lambda texts: ollama_embedding(
+        func=lambda texts: ollama_embed(
            texts, embed_model="nomic-embed-text", host="http://localhost:11434"
         ),
     ),
examples/lightrag_ollama_demo.py CHANGED
@@ -3,7 +3,7 @@ import os
 import inspect
 import logging
 from lightrag import LightRAG, QueryParam
-from lightrag.llm import ollama_model_complete, ollama_embedding
+from lightrag.llm.ollama import ollama_model_complete, ollama_embed
 from lightrag.utils import EmbeddingFunc
 
 WORKING_DIR = "./dickens"
@@ -23,7 +23,7 @@ rag = LightRAG(
     embedding_func=EmbeddingFunc(
         embedding_dim=768,
         max_token_size=8192,
-        func=lambda texts: ollama_embedding(
+        func=lambda texts: ollama_embed(
             texts, embed_model="nomic-embed-text", host="http://localhost:11434"
         ),
     ),
examples/lightrag_ollama_gremlin_demo.py CHANGED
@@ -10,7 +10,7 @@ import os
 # logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.WARN)
 
 from lightrag import LightRAG, QueryParam
-from lightrag.llm import ollama_embedding, ollama_model_complete
+from lightrag.llm.ollama import ollama_embed, ollama_model_complete
 from lightrag.utils import EmbeddingFunc
 
 WORKING_DIR = "./dickens_gremlin"
@@ -41,7 +41,7 @@ rag = LightRAG(
     embedding_func=EmbeddingFunc(
         embedding_dim=768,
         max_token_size=8192,
-        func=lambda texts: ollama_embedding(
+        func=lambda texts: ollama_embed(
             texts, embed_model="nomic-embed-text", host="http://localhost:11434"
         ),
     ),
examples/lightrag_ollama_neo4j_milvus_mongo_demo.py CHANGED
@@ -1,6 +1,6 @@
 import os
 from lightrag import LightRAG, QueryParam
-from lightrag.llm import ollama_model_complete, ollama_embed
+from lightrag.llm.ollama import ollama_model_complete, ollama_embed
 from lightrag.utils import EmbeddingFunc
 
 # WorkingDir
examples/lightrag_openai_compatible_demo.py CHANGED
@@ -1,7 +1,7 @@
 import os
 import asyncio
 from lightrag import LightRAG, QueryParam
-from lightrag.llm import openai_complete_if_cache, openai_embedding
+from lightrag.llm.openai import openai_complete_if_cache, openai_embed
 from lightrag.utils import EmbeddingFunc
 import numpy as np
 
@@ -26,7 +26,7 @@ async def llm_model_func(
 
 
 async def embedding_func(texts: list[str]) -> np.ndarray:
-    return await openai_embedding(
+    return await openai_embed(
         texts,
         model="solar-embedding-1-large-query",
         api_key=os.getenv("UPSTAGE_API_KEY"),
examples/lightrag_openai_compatible_demo_embedding_cache.py CHANGED
@@ -1,7 +1,7 @@
 import os
 import asyncio
 from lightrag import LightRAG, QueryParam
-from lightrag.llm import openai_complete_if_cache, openai_embedding
+from lightrag.llm.openai import openai_complete_if_cache, openai_embed
 from lightrag.utils import EmbeddingFunc
 import numpy as np
 
@@ -26,7 +26,7 @@ async def llm_model_func(
 
 
 async def embedding_func(texts: list[str]) -> np.ndarray:
-    return await openai_embedding(
+    return await openai_embed(
         texts,
         model="solar-embedding-1-large-query",
         api_key=os.getenv("UPSTAGE_API_KEY"),
examples/lightrag_openai_compatible_stream_demo.py CHANGED
@@ -1,7 +1,7 @@
 import os
 import inspect
 from lightrag import LightRAG
-from lightrag.llm import openai_complete, openai_embedding
+from lightrag.llm import openai_complete, openai_embed
 from lightrag.utils import EmbeddingFunc
 from lightrag.lightrag import always_get_an_event_loop
 from lightrag import QueryParam
@@ -24,7 +24,7 @@ rag = LightRAG(
     embedding_func=EmbeddingFunc(
         embedding_dim=1024,
         max_token_size=8192,
-        func=lambda texts: openai_embedding(
+        func=lambda texts: openai_embed(
             texts=texts,
             model="text-embedding-bge-m3",
             base_url="http://127.0.0.1:1234/v1",
examples/lightrag_openai_demo.py CHANGED
@@ -1,7 +1,7 @@
 import os
 
 from lightrag import LightRAG, QueryParam
-from lightrag.llm import gpt_4o_mini_complete
+from lightrag.llm.openai import gpt_4o_mini_complete
 
 WORKING_DIR = "./dickens"
 
examples/lightrag_openai_neo4j_milvus_redis_demo.py CHANGED
@@ -1,6 +1,6 @@
 import os
 from lightrag import LightRAG, QueryParam
-from lightrag.llm import ollama_embed, openai_complete_if_cache
+from lightrag.llm.ollama import ollama_embed, openai_complete_if_cache
 from lightrag.utils import EmbeddingFunc
 
 # WorkingDir
examples/lightrag_oracle_demo.py CHANGED
@@ -3,7 +3,7 @@ import os
 from pathlib import Path
 import asyncio
 from lightrag import LightRAG, QueryParam
-from lightrag.llm import openai_complete_if_cache, openai_embedding
+from lightrag.llm.openai import openai_complete_if_cache, openai_embed
 from lightrag.utils import EmbeddingFunc
 import numpy as np
 from lightrag.kg.oracle_impl import OracleDB
@@ -42,7 +42,7 @@ async def llm_model_func(
 
 
 async def embedding_func(texts: list[str]) -> np.ndarray:
-    return await openai_embedding(
+    return await openai_embed(
         texts,
         model=EMBEDMODEL,
         api_key=APIKEY,
examples/lightrag_siliconcloud_demo.py CHANGED
@@ -1,7 +1,8 @@
 import os
 import asyncio
 from lightrag import LightRAG, QueryParam
-from lightrag.llm import openai_complete_if_cache, siliconcloud_embedding
+from lightrag.llm.openai import openai_complete_if_cache
+from lightrag.llm.siliconcloud import siliconcloud_embedding
 from lightrag.utils import EmbeddingFunc
 import numpy as np
 
examples/lightrag_zhipu_demo.py CHANGED
@@ -3,7 +3,7 @@ import logging
 
 
 from lightrag import LightRAG, QueryParam
-from lightrag.llm import zhipu_complete, zhipu_embedding
+from lightrag.llm.zhipu import zhipu_complete, zhipu_embedding
 from lightrag.utils import EmbeddingFunc
 
 WORKING_DIR = "./dickens"
examples/lightrag_zhipu_postgres_demo.py CHANGED
@@ -6,7 +6,7 @@ from dotenv import load_dotenv
 
 from lightrag import LightRAG, QueryParam
 from lightrag.kg.postgres_impl import PostgreSQLDB
-from lightrag.llm import ollama_embedding, zhipu_complete
+from lightrag.llm.zhipu import ollama_embedding, zhipu_complete
 from lightrag.utils import EmbeddingFunc
 
 load_dotenv()
examples/test.py CHANGED
@@ -1,6 +1,6 @@
 import os
 from lightrag import LightRAG, QueryParam
-from lightrag.llm import gpt_4o_mini_complete
+from lightrag.llm.openai import gpt_4o_mini_complete
 #########
 # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
 # import nest_asyncio
examples/test_chromadb.py CHANGED
@@ -1,7 +1,7 @@
 import os
 import asyncio
 from lightrag import LightRAG, QueryParam
-from lightrag.llm import gpt_4o_mini_complete, openai_embedding
+from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed
 from lightrag.utils import EmbeddingFunc
 import numpy as np
 
@@ -35,7 +35,7 @@ EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
 
 
 async def embedding_func(texts: list[str]) -> np.ndarray:
-    return await openai_embedding(
+    return await openai_embed(
         texts,
         model=EMBEDDING_MODEL,
     )
examples/test_neo4j.py CHANGED
@@ -1,6 +1,6 @@
 import os
 from lightrag import LightRAG, QueryParam
-from lightrag.llm import gpt_4o_mini_complete
+from lightrag.llm.openai import gpt_4o_mini_complete
 
 
 #########
examples/test_split_by_character.ipynb CHANGED
@@ -16,7 +16,7 @@
     "import logging\n",
     "import numpy as np\n",
     "from lightrag import LightRAG, QueryParam\n",
-    "from lightrag.llm import openai_complete_if_cache, openai_embedding\n",
+    "from lightrag.llm.openai import openai_complete_if_cache, openai_embed\n",
     "from lightrag.utils import EmbeddingFunc\n",
     "import nest_asyncio"
   ]
@@ -74,7 +74,7 @@
     "\n",
     "\n",
     "async def embedding_func(texts: list[str]) -> np.ndarray:\n",
-    "    return await openai_embedding(\n",
+    "    return await openai_embed(\n",
     "        texts,\n",
     "        model=\"ep-20241231173413-pgjmk\",\n",
     "        api_key=API,\n",
@@ -138,7 +138,7 @@
     "\n",
     "\n",
     "async def embedding_func(texts: list[str]) -> np.ndarray:\n",
-    "    return await openai_embedding(\n",
+    "    return await openai_embed(\n",
     "        texts,\n",
     "        model=\"ep-20241231173413-pgjmk\",\n",
     "        api_key=API,\n",
examples/vram_management_demo.py CHANGED
@@ -1,7 +1,7 @@
 import os
 import time
 from lightrag import LightRAG, QueryParam
-from lightrag.llm import ollama_model_complete, ollama_embedding
+from lightrag.llm.ollama import ollama_model_complete, ollama_embed
 from lightrag.utils import EmbeddingFunc
 
 # Working directory and the directory path for text files
@@ -20,7 +20,7 @@ rag = LightRAG(
     embedding_func=EmbeddingFunc(
         embedding_dim=768,
         max_token_size=8192,
-        func=lambda texts: ollama_embedding(texts, embed_model="nomic-embed-text"),
+        func=lambda texts: ollama_embed(texts, embed_model="nomic-embed-text"),
     ),
 )
 
lightrag/api/lightrag_server.py CHANGED
@@ -8,10 +8,6 @@ import time
 import re
 from typing import List, Dict, Any, Optional, Union
 from lightrag import LightRAG, QueryParam
-from lightrag.llm import lollms_model_complete, lollms_embed
-from lightrag.llm import ollama_model_complete, ollama_embed
-from lightrag.llm import openai_complete_if_cache, openai_embedding
-from lightrag.llm import azure_openai_complete_if_cache, azure_openai_embedding
 from lightrag.api import __api_version__
 
 from lightrag.utils import EmbeddingFunc
@@ -720,6 +716,20 @@ def create_app(args):
 
     # Create working directory if it doesn't exist
     Path(args.working_dir).mkdir(parents=True, exist_ok=True)
+    if args.llm_binding_host == "lollms" or args.embedding_binding == "lollms":
+        from lightrag.llm.lollms import lollms_model_complete, lollms_embed
+    if args.llm_binding_host == "ollama" or args.embedding_binding == "ollama":
+        from lightrag.llm.ollama import ollama_model_complete, ollama_embed
+    if args.llm_binding_host == "openai" or args.embedding_binding == "openai":
+        from lightrag.llm.openai import openai_complete_if_cache, openai_embed
+    if (
+        args.llm_binding_host == "azure_openai"
+        or args.embedding_binding == "azure_openai"
+    ):
+        from lightrag.llm.azure_openai import (
+            azure_openai_complete_if_cache,
+            azure_openai_embed,
+        )
 
     async def openai_alike_model_complete(
         prompt,
@@ -773,13 +783,13 @@ def create_app(args):
             api_key=args.embedding_binding_api_key,
         )
         if args.embedding_binding == "ollama"
-        else azure_openai_embedding(
+        else azure_openai_embed(
             texts,
             model=args.embedding_model,  # no host is used for openai,
             api_key=args.embedding_binding_api_key,
         )
         if args.embedding_binding == "azure_openai"
-        else openai_embedding(
+        else openai_embed(
             texts,
             model=args.embedding_model,  # no host is used for openai,
             api_key=args.embedding_binding_api_key,
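With the provider modules split out, the API server no longer imports every binding at module load time; it imports only the binding selected by the arguments, inside `create_app`. A simplified sketch of that lazy-import pattern (the function name and plain `binding` parameter are illustrative; the real code branches on `args.llm_binding_host` / `args.embedding_binding` and also covers the lollms and azure_openai bindings):

```python
def resolve_llm_binding(binding: str):
    """Import and return the completion helper for the selected binding only when it is needed."""
    if binding == "ollama":
        from lightrag.llm.ollama import ollama_model_complete
        return ollama_model_complete
    if binding == "openai":
        from lightrag.llm.openai import openai_complete_if_cache
        return openai_complete_if_cache
    raise ValueError(f"Unsupported llm binding: {binding}")


# Example: nothing from lightrag.llm.openai is imported unless "openai" is selected.
llm_func = resolve_llm_binding("ollama")
```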
lightrag/api/requirements.txt CHANGED
@@ -1,4 +1,3 @@
-aioboto3
 ascii_colors
 fastapi
 nano_vectordb
lightrag/exceptions.py ADDED
@@ -0,0 +1,55 @@
+import httpx
+from typing import Literal
+
+class APIStatusError(Exception):
+    """Raised when an API response has a status code of 4xx or 5xx."""
+
+    response: httpx.Response
+    status_code: int
+    request_id: str | None
+
+    def __init__(self, message: str, *, response: httpx.Response, body: object | None) -> None:
+        super().__init__(message, response.request, body=body)
+        self.response = response
+        self.status_code = response.status_code
+        self.request_id = response.headers.get("x-request-id")
+
+class APIConnectionError(Exception):
+    def __init__(self, *, message: str = "Connection error.", request: httpx.Request) -> None:
+        super().__init__(message, request, body=None)
+
+
+class BadRequestError(APIStatusError):
+    status_code: Literal[400] = 400  # pyright: ignore[reportIncompatibleVariableOverride]
+
+
+class AuthenticationError(APIStatusError):
+    status_code: Literal[401] = 401  # pyright: ignore[reportIncompatibleVariableOverride]
+
+
+class PermissionDeniedError(APIStatusError):
+    status_code: Literal[403] = 403  # pyright: ignore[reportIncompatibleVariableOverride]
+
+
+class NotFoundError(APIStatusError):
+    status_code: Literal[404] = 404  # pyright: ignore[reportIncompatibleVariableOverride]
+
+
+class ConflictError(APIStatusError):
+    status_code: Literal[409] = 409  # pyright: ignore[reportIncompatibleVariableOverride]
+
+
+class UnprocessableEntityError(APIStatusError):
+    status_code: Literal[422] = 422  # pyright: ignore[reportIncompatibleVariableOverride]
+
+
+class RateLimitError(APIStatusError):
+    status_code: Literal[429] = 429  # pyright: ignore[reportIncompatibleVariableOverride]
+
+class APITimeoutError(APIConnectionError):
+    def __init__(self, request: httpx.Request) -> None:
+        super().__init__(message="Request timed out.", request=request)
+
+
+class BadRequestError(APIStatusError):
+    status_code: Literal[400] = 400  # pyright: ignore[reportIncompatibleVariableOverride]
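The new `lightrag/exceptions.py` defines OpenAI-SDK-style status and connection errors so retry logic in the provider modules can depend on LightRAG's own types. A sketch of how they can be used with the same tenacity decorator the old `llm.py` applied; `call_provider_with_retries` and its `make_request` parameter are hypothetical names, not part of the commit:

```python
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential

from lightrag.exceptions import APIConnectionError, APITimeoutError, RateLimitError


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((RateLimitError, APIConnectionError, APITimeoutError)),
)
async def call_provider_with_retries(make_request):
    # make_request is a caller-supplied coroutine factory (hypothetical)
    return await make_request()
```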
lightrag/kg/redis_impl.py CHANGED
@@ -1,7 +1,8 @@
 import os
 from tqdm.asyncio import tqdm as tqdm_async
 from dataclasses import dataclass
-import aioredis
+# aioredis is a deprecated library, replaced with redis
+from redis.asyncio import Redis
 from lightrag.utils import logger
 from lightrag.base import BaseKVStorage
 import json
@@ -11,7 +12,7 @@ import json
 class RedisKVStorage(BaseKVStorage):
     def __post_init__(self):
         redis_url = os.environ.get("REDIS_URI", "redis://localhost:6379")
-        self._redis = aioredis.from_url(redis_url, decode_responses=True)
+        self._redis = Redis.from_url(redis_url, decode_responses=True)
         logger.info(f"Use Redis as KV {self.namespace}")
 
     async def all_keys(self) -> list[str]:
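`aioredis` is deprecated; its asyncio client has shipped inside redis-py as `redis.asyncio` since redis-py 4.2, which is what the storage class now uses. A standalone sketch of the replacement client, assuming a Redis server on the default localhost port:

```python
import asyncio

# redis-py provides the former aioredis API under redis.asyncio
from redis.asyncio import Redis


async def main() -> None:
    # decode_responses=True returns str instead of bytes, matching the previous aioredis setup
    redis = Redis.from_url("redis://localhost:6379", decode_responses=True)
    await redis.set("lightrag:ping", "pong")
    print(await redis.get("lightrag:ping"))
    await redis.close()


asyncio.run(main())
```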
lightrag/lightrag.py CHANGED
@@ -6,10 +6,6 @@ from datetime import datetime
 from functools import partial
 from typing import Type, cast, Dict
 
-from .llm import (
-    gpt_4o_mini_complete,
-    openai_embedding,
-)
 from .operate import (
     chunking_by_token_size,
     extract_entities,
@@ -154,12 +150,12 @@ class LightRAG:
     )
 
     # embedding_func: EmbeddingFunc = field(default_factory=lambda:hf_embedding)
-    embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
+    embedding_func: EmbeddingFunc = None  # This must be set (we want to separate the LLMs from the core, so there is no default initialization any more)
     embedding_batch_num: int = 32
     embedding_func_max_async: int = 16
 
     # LLM
-    llm_model_func: callable = gpt_4o_mini_complete  # hf_model_complete#
+    llm_model_func: callable = None  # This must be set (we want to separate the LLMs from the core, so there is no default initialization any more)
     llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"  # 'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
     llm_model_max_token_size: int = 32768
     llm_model_max_async: int = 16
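Because the default `openai_embedding` / `gpt_4o_mini_complete` values are gone, the core no longer pulls in any provider package implicitly, and callers have to pass both functions themselves. A sketch of the explicit wiring needed after this commit (the path and the embedding dimension are illustrative assumptions):

```python
from lightrag import LightRAG
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed
from lightrag.utils import EmbeddingFunc

rag = LightRAG(
    working_dir="./rag_storage",            # illustrative
    llm_model_func=gpt_4o_mini_complete,    # must be set explicitly now
    embedding_func=EmbeddingFunc(           # must be set explicitly now
        embedding_dim=1536,  # assumption: dimension of the default OpenAI embedding model
        max_token_size=8192,
        func=openai_embed,
    ),
)
```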
lightrag/llm.py CHANGED
@@ -1,1211 +1,5 @@
1
- import base64
2
- import copy
3
- import json
4
- import os
5
- import re
6
- import struct
7
- from functools import lru_cache
8
- from typing import List, Dict, Callable, Any, Union, Optional
9
- import aioboto3
10
- import aiohttp
11
- import numpy as np
12
- import ollama
13
- import torch
14
- from openai import (
15
- AsyncOpenAI,
16
- APIConnectionError,
17
- RateLimitError,
18
- APITimeoutError,
19
- AsyncAzureOpenAI,
20
- )
21
  from pydantic import BaseModel, Field
22
- from tenacity import (
23
- retry,
24
- stop_after_attempt,
25
- wait_exponential,
26
- retry_if_exception_type,
27
- )
28
- from transformers import AutoTokenizer, AutoModelForCausalLM
29
-
30
- from .utils import (
31
- wrap_embedding_func_with_attrs,
32
- locate_json_string_body_from_string,
33
- safe_unicode_decode,
34
- logger,
35
- )
36
-
37
- import sys
38
-
39
- if sys.version_info < (3, 9):
40
- from typing import AsyncIterator
41
- else:
42
- from collections.abc import AsyncIterator
43
-
44
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
45
-
46
-
47
- @retry(
48
- stop=stop_after_attempt(3),
49
- wait=wait_exponential(multiplier=1, min=4, max=10),
50
- retry=retry_if_exception_type(
51
- (RateLimitError, APIConnectionError, APITimeoutError)
52
- ),
53
- )
54
- async def openai_complete_if_cache(
55
- model,
56
- prompt,
57
- system_prompt=None,
58
- history_messages=[],
59
- base_url=None,
60
- api_key=None,
61
- **kwargs,
62
- ) -> str:
63
- if api_key:
64
- os.environ["OPENAI_API_KEY"] = api_key
65
-
66
- openai_async_client = (
67
- AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
68
- )
69
- kwargs.pop("hashing_kv", None)
70
- kwargs.pop("keyword_extraction", None)
71
- messages = []
72
- if system_prompt:
73
- messages.append({"role": "system", "content": system_prompt})
74
- messages.extend(history_messages)
75
- messages.append({"role": "user", "content": prompt})
76
-
77
- # 添加日志输出
78
- logger.debug("===== Query Input to LLM =====")
79
- logger.debug(f"Query: {prompt}")
80
- logger.debug(f"System prompt: {system_prompt}")
81
- logger.debug("Full context:")
82
- if "response_format" in kwargs:
83
- response = await openai_async_client.beta.chat.completions.parse(
84
- model=model, messages=messages, **kwargs
85
- )
86
- else:
87
- response = await openai_async_client.chat.completions.create(
88
- model=model, messages=messages, **kwargs
89
- )
90
-
91
- if hasattr(response, "__aiter__"):
92
-
93
- async def inner():
94
- async for chunk in response:
95
- content = chunk.choices[0].delta.content
96
- if content is None:
97
- continue
98
- if r"\u" in content:
99
- content = safe_unicode_decode(content.encode("utf-8"))
100
- yield content
101
-
102
- return inner()
103
- else:
104
- content = response.choices[0].message.content
105
- if r"\u" in content:
106
- content = safe_unicode_decode(content.encode("utf-8"))
107
- return content
108
-
109
-
110
- @retry(
111
- stop=stop_after_attempt(3),
112
- wait=wait_exponential(multiplier=1, min=4, max=10),
113
- retry=retry_if_exception_type(
114
- (RateLimitError, APIConnectionError, APIConnectionError)
115
- ),
116
- )
117
- async def azure_openai_complete_if_cache(
118
- model,
119
- prompt,
120
- system_prompt=None,
121
- history_messages=[],
122
- base_url=None,
123
- api_key=None,
124
- api_version=None,
125
- **kwargs,
126
- ):
127
- if api_key:
128
- os.environ["AZURE_OPENAI_API_KEY"] = api_key
129
- if base_url:
130
- os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
131
- if api_version:
132
- os.environ["AZURE_OPENAI_API_VERSION"] = api_version
133
-
134
- openai_async_client = AsyncAzureOpenAI(
135
- azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
136
- api_key=os.getenv("AZURE_OPENAI_API_KEY"),
137
- api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
138
- )
139
- kwargs.pop("hashing_kv", None)
140
- messages = []
141
- if system_prompt:
142
- messages.append({"role": "system", "content": system_prompt})
143
- messages.extend(history_messages)
144
- if prompt is not None:
145
- messages.append({"role": "user", "content": prompt})
146
-
147
- if "response_format" in kwargs:
148
- response = await openai_async_client.beta.chat.completions.parse(
149
- model=model, messages=messages, **kwargs
150
- )
151
- else:
152
- response = await openai_async_client.chat.completions.create(
153
- model=model, messages=messages, **kwargs
154
- )
155
-
156
- if hasattr(response, "__aiter__"):
157
-
158
- async def inner():
159
- async for chunk in response:
160
- if len(chunk.choices) == 0:
161
- continue
162
- content = chunk.choices[0].delta.content
163
- if content is None:
164
- continue
165
- if r"\u" in content:
166
- content = safe_unicode_decode(content.encode("utf-8"))
167
- yield content
168
-
169
- return inner()
170
- else:
171
- content = response.choices[0].message.content
172
- if r"\u" in content:
173
- content = safe_unicode_decode(content.encode("utf-8"))
174
- return content
175
-
176
-
177
- class BedrockError(Exception):
178
- """Generic error for issues related to Amazon Bedrock"""
179
-
180
-
181
- @retry(
182
- stop=stop_after_attempt(5),
183
- wait=wait_exponential(multiplier=1, max=60),
184
- retry=retry_if_exception_type((BedrockError)),
185
- )
186
- async def bedrock_complete_if_cache(
187
- model,
188
- prompt,
189
- system_prompt=None,
190
- history_messages=[],
191
- aws_access_key_id=None,
192
- aws_secret_access_key=None,
193
- aws_session_token=None,
194
- **kwargs,
195
- ) -> str:
196
- os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
197
- "AWS_ACCESS_KEY_ID", aws_access_key_id
198
- )
199
- os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
200
- "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
201
- )
202
- os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
203
- "AWS_SESSION_TOKEN", aws_session_token
204
- )
205
- kwargs.pop("hashing_kv", None)
206
- # Fix message history format
207
- messages = []
208
- for history_message in history_messages:
209
- message = copy.copy(history_message)
210
- message["content"] = [{"text": message["content"]}]
211
- messages.append(message)
212
-
213
- # Add user prompt
214
- messages.append({"role": "user", "content": [{"text": prompt}]})
215
-
216
- # Initialize Converse API arguments
217
- args = {"modelId": model, "messages": messages}
218
-
219
- # Define system prompt
220
- if system_prompt:
221
- args["system"] = [{"text": system_prompt}]
222
-
223
- # Map and set up inference parameters
224
- inference_params_map = {
225
- "max_tokens": "maxTokens",
226
- "top_p": "topP",
227
- "stop_sequences": "stopSequences",
228
- }
229
- if inference_params := list(
230
- set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
231
- ):
232
- args["inferenceConfig"] = {}
233
- for param in inference_params:
234
- args["inferenceConfig"][inference_params_map.get(param, param)] = (
235
- kwargs.pop(param)
236
- )
237
-
238
- # Call model via Converse API
239
- session = aioboto3.Session()
240
- async with session.client("bedrock-runtime") as bedrock_async_client:
241
- try:
242
- response = await bedrock_async_client.converse(**args, **kwargs)
243
- except Exception as e:
244
- raise BedrockError(e)
245
-
246
- return response["output"]["message"]["content"][0]["text"]
247
-
248
-
249
- @lru_cache(maxsize=1)
250
- def initialize_hf_model(model_name):
251
- hf_tokenizer = AutoTokenizer.from_pretrained(
252
- model_name, device_map="auto", trust_remote_code=True
253
- )
254
- hf_model = AutoModelForCausalLM.from_pretrained(
255
- model_name, device_map="auto", trust_remote_code=True
256
- )
257
- if hf_tokenizer.pad_token is None:
258
- hf_tokenizer.pad_token = hf_tokenizer.eos_token
259
-
260
- return hf_model, hf_tokenizer
261
-
262
-
263
- @retry(
264
- stop=stop_after_attempt(3),
265
- wait=wait_exponential(multiplier=1, min=4, max=10),
266
- retry=retry_if_exception_type(
267
- (RateLimitError, APIConnectionError, APITimeoutError)
268
- ),
269
- )
270
- async def hf_model_if_cache(
271
- model,
272
- prompt,
273
- system_prompt=None,
274
- history_messages=[],
275
- **kwargs,
276
- ) -> str:
277
- model_name = model
278
- hf_model, hf_tokenizer = initialize_hf_model(model_name)
279
- messages = []
280
- if system_prompt:
281
- messages.append({"role": "system", "content": system_prompt})
282
- messages.extend(history_messages)
283
- messages.append({"role": "user", "content": prompt})
284
- kwargs.pop("hashing_kv", None)
285
- input_prompt = ""
286
- try:
287
- input_prompt = hf_tokenizer.apply_chat_template(
288
- messages, tokenize=False, add_generation_prompt=True
289
- )
290
- except Exception:
291
- try:
292
- ori_message = copy.deepcopy(messages)
293
- if messages[0]["role"] == "system":
294
- messages[1]["content"] = (
295
- "<system>"
296
- + messages[0]["content"]
297
- + "</system>\n"
298
- + messages[1]["content"]
299
- )
300
- messages = messages[1:]
301
- input_prompt = hf_tokenizer.apply_chat_template(
302
- messages, tokenize=False, add_generation_prompt=True
303
- )
304
- except Exception:
305
- len_message = len(ori_message)
306
- for msgid in range(len_message):
307
- input_prompt = (
308
- input_prompt
309
- + "<"
310
- + ori_message[msgid]["role"]
311
- + ">"
312
- + ori_message[msgid]["content"]
313
- + "</"
314
- + ori_message[msgid]["role"]
315
- + ">\n"
316
- )
317
-
318
- input_ids = hf_tokenizer(
319
- input_prompt, return_tensors="pt", padding=True, truncation=True
320
- ).to("cuda")
321
- inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
322
- output = hf_model.generate(
323
- **input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True
324
- )
325
- response_text = hf_tokenizer.decode(
326
- output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
327
- )
328
-
329
- return response_text
330
-
331
-
332
- @retry(
333
- stop=stop_after_attempt(3),
334
- wait=wait_exponential(multiplier=1, min=4, max=10),
335
- retry=retry_if_exception_type(
336
- (RateLimitError, APIConnectionError, APITimeoutError)
337
- ),
338
- )
339
- async def ollama_model_if_cache(
340
- model,
341
- prompt,
342
- system_prompt=None,
343
- history_messages=[],
344
- **kwargs,
345
- ) -> Union[str, AsyncIterator[str]]:
346
- stream = True if kwargs.get("stream") else False
347
- kwargs.pop("max_tokens", None)
348
- # kwargs.pop("response_format", None) # allow json
349
- host = kwargs.pop("host", None)
350
- timeout = kwargs.pop("timeout", None)
351
- kwargs.pop("hashing_kv", None)
352
- api_key = kwargs.pop("api_key", None)
353
- headers = (
354
- {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
355
- if api_key
356
- else {"Content-Type": "application/json"}
357
- )
358
- ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
359
- messages = []
360
- if system_prompt:
361
- messages.append({"role": "system", "content": system_prompt})
362
- messages.extend(history_messages)
363
- messages.append({"role": "user", "content": prompt})
364
-
365
- response = await ollama_client.chat(model=model, messages=messages, **kwargs)
366
- if stream:
367
- """cannot cache stream response"""
368
-
369
- async def inner():
370
- async for chunk in response:
371
- yield chunk["message"]["content"]
372
-
373
- return inner()
374
- else:
375
- return response["message"]["content"]
376
-
377
-
378
- async def lollms_model_if_cache(
379
- model,
380
- prompt,
381
- system_prompt=None,
382
- history_messages=[],
383
- base_url="http://localhost:9600",
384
- **kwargs,
385
- ) -> Union[str, AsyncIterator[str]]:
386
- """Client implementation for lollms generation."""
387
-
388
- stream = True if kwargs.get("stream") else False
389
- api_key = kwargs.pop("api_key", None)
390
- headers = (
391
- {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
392
- if api_key
393
- else {"Content-Type": "application/json"}
394
- )
395
-
396
- # Extract lollms specific parameters
397
- request_data = {
398
- "prompt": prompt,
399
- "model_name": model,
400
- "personality": kwargs.get("personality", -1),
401
- "n_predict": kwargs.get("n_predict", None),
402
- "stream": stream,
403
- "temperature": kwargs.get("temperature", 0.1),
404
- "top_k": kwargs.get("top_k", 50),
405
- "top_p": kwargs.get("top_p", 0.95),
406
- "repeat_penalty": kwargs.get("repeat_penalty", 0.8),
407
- "repeat_last_n": kwargs.get("repeat_last_n", 40),
408
- "seed": kwargs.get("seed", None),
409
- "n_threads": kwargs.get("n_threads", 8),
410
- }
411
-
412
- # Prepare the full prompt including history
413
- full_prompt = ""
414
- if system_prompt:
415
- full_prompt += f"{system_prompt}\n"
416
- for msg in history_messages:
417
- full_prompt += f"{msg['role']}: {msg['content']}\n"
418
- full_prompt += prompt
419
-
420
- request_data["prompt"] = full_prompt
421
- timeout = aiohttp.ClientTimeout(total=kwargs.get("timeout", None))
422
-
423
- async with aiohttp.ClientSession(timeout=timeout, headers=headers) as session:
424
- if stream:
425
-
426
- async def inner():
427
- async with session.post(
428
- f"{base_url}/lollms_generate", json=request_data
429
- ) as response:
430
- async for line in response.content:
431
- yield line.decode().strip()
432
-
433
- return inner()
434
- else:
435
- async with session.post(
436
- f"{base_url}/lollms_generate", json=request_data
437
- ) as response:
438
- return await response.text()
439
-
440
-
441
- @lru_cache(maxsize=1)
442
- def initialize_lmdeploy_pipeline(
443
- model,
444
- tp=1,
445
- chat_template=None,
446
- log_level="WARNING",
447
- model_format="hf",
448
- quant_policy=0,
449
- ):
450
- from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig
451
-
452
- lmdeploy_pipe = pipeline(
453
- model_path=model,
454
- backend_config=TurbomindEngineConfig(
455
- tp=tp, model_format=model_format, quant_policy=quant_policy
456
- ),
457
- chat_template_config=(
458
- ChatTemplateConfig(model_name=chat_template) if chat_template else None
459
- ),
460
- log_level="WARNING",
461
- )
462
- return lmdeploy_pipe
463
-
464
-
465
- @retry(
466
- stop=stop_after_attempt(3),
467
- wait=wait_exponential(multiplier=1, min=4, max=10),
468
- retry=retry_if_exception_type(
469
- (RateLimitError, APIConnectionError, APITimeoutError)
470
- ),
471
- )
472
- async def lmdeploy_model_if_cache(
473
- model,
474
- prompt,
475
- system_prompt=None,
476
- history_messages=[],
477
- chat_template=None,
478
- model_format="hf",
479
- quant_policy=0,
480
- **kwargs,
481
- ) -> str:
482
- """
483
- Args:
484
- model (str): The path to the model.
485
- It could be one of the following options:
486
- - i) A local directory path of a turbomind model which is
487
- converted by `lmdeploy convert` command or download
488
- from ii) and iii).
489
- - ii) The model_id of a lmdeploy-quantized model hosted
490
- inside a model repo on huggingface.co, such as
491
- "InternLM/internlm-chat-20b-4bit",
492
- "lmdeploy/llama2-chat-70b-4bit", etc.
493
- - iii) The model_id of a model hosted inside a model repo
494
- on huggingface.co, such as "internlm/internlm-chat-7b",
495
- "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
496
- and so on.
497
- chat_template (str): needed when model is a pytorch model on
498
- huggingface.co, such as "internlm-chat-7b",
499
- "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on,
500
- and when the model name of local path did not match the original model name in HF.
501
- tp (int): tensor parallel
502
- prompt (Union[str, List[str]]): input texts to be completed.
503
- do_preprocess (bool): whether pre-process the messages. Default to
504
- True, which means chat_template will be applied.
505
- skip_special_tokens (bool): Whether or not to remove special tokens
506
- in the decoding. Default to be True.
507
- do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise.
508
- Default to be False, which means greedy decoding will be applied.
509
- """
510
- try:
511
- import lmdeploy
512
- from lmdeploy import version_info, GenerationConfig
513
- except Exception:
514
- raise ImportError("Please install lmdeploy before initialize lmdeploy backend.")
515
- kwargs.pop("hashing_kv", None)
516
- kwargs.pop("response_format", None)
517
- max_new_tokens = kwargs.pop("max_tokens", 512)
518
- tp = kwargs.pop("tp", 1)
519
- skip_special_tokens = kwargs.pop("skip_special_tokens", True)
520
- do_preprocess = kwargs.pop("do_preprocess", True)
521
- do_sample = kwargs.pop("do_sample", False)
522
- gen_params = kwargs
523
-
524
- version = version_info
525
- if do_sample is not None and version < (0, 6, 0):
526
- raise RuntimeError(
527
- "`do_sample` parameter is not supported by lmdeploy until "
528
- f"v0.6.0, but currently using lmdeloy {lmdeploy.__version__}"
529
- )
530
- else:
531
- do_sample = True
532
- gen_params.update(do_sample=do_sample)
533
-
534
- lmdeploy_pipe = initialize_lmdeploy_pipeline(
535
- model=model,
536
- tp=tp,
537
- chat_template=chat_template,
538
- model_format=model_format,
539
- quant_policy=quant_policy,
540
- log_level="WARNING",
541
- )
542
-
543
- messages = []
544
- if system_prompt:
545
- messages.append({"role": "system", "content": system_prompt})
546
-
547
- messages.extend(history_messages)
548
- messages.append({"role": "user", "content": prompt})
549
-
550
- gen_config = GenerationConfig(
551
- skip_special_tokens=skip_special_tokens,
552
- max_new_tokens=max_new_tokens,
553
- **gen_params,
554
- )
555
-
556
- response = ""
557
- async for res in lmdeploy_pipe.generate(
558
- messages,
559
- gen_config=gen_config,
560
- do_preprocess=do_preprocess,
561
- stream_response=False,
562
- session_id=1,
563
- ):
564
- response += res.response
565
- return response
566
-
567
-
568
- class GPTKeywordExtractionFormat(BaseModel):
569
- high_level_keywords: List[str]
570
- low_level_keywords: List[str]
571
-
572
-
573
- async def openai_complete(
574
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
575
- ) -> Union[str, AsyncIterator[str]]:
576
- keyword_extraction = kwargs.pop("keyword_extraction", None)
577
- if keyword_extraction:
578
- kwargs["response_format"] = "json"
579
- model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
580
- return await openai_complete_if_cache(
581
- model_name,
582
- prompt,
583
- system_prompt=system_prompt,
584
- history_messages=history_messages,
585
- **kwargs,
586
- )
587
-
588
-
589
- async def gpt_4o_complete(
590
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
591
- ) -> str:
592
- keyword_extraction = kwargs.pop("keyword_extraction", None)
593
- if keyword_extraction:
594
- kwargs["response_format"] = GPTKeywordExtractionFormat
595
- return await openai_complete_if_cache(
596
- "gpt-4o",
597
- prompt,
598
- system_prompt=system_prompt,
599
- history_messages=history_messages,
600
- **kwargs,
601
- )
602
-
603
-
604
- async def gpt_4o_mini_complete(
605
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
606
- ) -> str:
607
- keyword_extraction = kwargs.pop("keyword_extraction", None)
608
- if keyword_extraction:
609
- kwargs["response_format"] = GPTKeywordExtractionFormat
610
- return await openai_complete_if_cache(
611
- "gpt-4o-mini",
612
- prompt,
613
- system_prompt=system_prompt,
614
- history_messages=history_messages,
615
- **kwargs,
616
- )
617
-
618
-
619
- async def nvidia_openai_complete(
620
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
621
- ) -> str:
622
- keyword_extraction = kwargs.pop("keyword_extraction", None)
623
- result = await openai_complete_if_cache(
624
- "nvidia/llama-3.1-nemotron-70b-instruct", # context length 128k
625
- prompt,
626
- system_prompt=system_prompt,
627
- history_messages=history_messages,
628
- base_url="https://integrate.api.nvidia.com/v1",
629
- **kwargs,
630
- )
631
- if keyword_extraction: # TODO: use JSON API
632
- return locate_json_string_body_from_string(result)
633
- return result
634
-
635
-
636
- async def azure_openai_complete(
637
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
638
- ) -> str:
639
- keyword_extraction = kwargs.pop("keyword_extraction", None)
640
- result = await azure_openai_complete_if_cache(
641
- os.getenv("LLM_MODEL", "gpt-4o-mini"),
642
- prompt,
643
- system_prompt=system_prompt,
644
- history_messages=history_messages,
645
- **kwargs,
646
- )
647
- if keyword_extraction: # TODO: use JSON API
648
- return locate_json_string_body_from_string(result)
649
- return result
650
-
651
-
652
- async def bedrock_complete(
653
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
654
- ) -> str:
655
- keyword_extraction = kwargs.pop("keyword_extraction", None)
656
- result = await bedrock_complete_if_cache(
657
- "anthropic.claude-3-haiku-20240307-v1:0",
658
- prompt,
659
- system_prompt=system_prompt,
660
- history_messages=history_messages,
661
- **kwargs,
662
- )
663
- if keyword_extraction: # TODO: use JSON API
664
- return locate_json_string_body_from_string(result)
665
- return result
666
-
667
-
668
- async def hf_model_complete(
669
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
670
- ) -> str:
671
- keyword_extraction = kwargs.pop("keyword_extraction", None)
672
- model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
673
- result = await hf_model_if_cache(
674
- model_name,
675
- prompt,
676
- system_prompt=system_prompt,
677
- history_messages=history_messages,
678
- **kwargs,
679
- )
680
- if keyword_extraction: # TODO: use JSON API
681
- return locate_json_string_body_from_string(result)
682
- return result
683
-
684
-
685
- async def ollama_model_complete(
686
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
687
- ) -> Union[str, AsyncIterator[str]]:
688
- keyword_extraction = kwargs.pop("keyword_extraction", None)
689
- if keyword_extraction:
690
- kwargs["format"] = "json"
691
- model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
692
- return await ollama_model_if_cache(
693
- model_name,
694
- prompt,
695
- system_prompt=system_prompt,
696
- history_messages=history_messages,
697
- **kwargs,
698
- )
699
-
700
-
701
- async def lollms_model_complete(
702
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
703
- ) -> Union[str, AsyncIterator[str]]:
704
- """Complete function for lollms model generation."""
705
-
706
- # Extract and remove keyword_extraction from kwargs if present
707
- keyword_extraction = kwargs.pop("keyword_extraction", None)
708
-
709
- # Get model name from config
710
- model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
711
-
712
- # If keyword extraction is needed, we might need to modify the prompt
713
- # or add specific parameters for JSON output (if lollms supports it)
714
- if keyword_extraction:
715
- # Note: You might need to adjust this based on how lollms handles structured output
716
- pass
717
-
718
- return await lollms_model_if_cache(
719
- model_name,
720
- prompt,
721
- system_prompt=system_prompt,
722
- history_messages=history_messages,
723
- **kwargs,
724
- )
725
-
726
-
727
- @retry(
728
- stop=stop_after_attempt(3),
729
- wait=wait_exponential(multiplier=1, min=4, max=10),
730
- retry=retry_if_exception_type(
731
- (RateLimitError, APIConnectionError, APITimeoutError)
732
- ),
733
- )
734
- async def zhipu_complete_if_cache(
735
- prompt: Union[str, List[Dict[str, str]]],
736
- model: str = "glm-4-flashx", # The most cost/performance balance model in glm-4 series
737
- api_key: Optional[str] = None,
738
- system_prompt: Optional[str] = None,
739
- history_messages: List[Dict[str, str]] = [],
740
- **kwargs,
741
- ) -> str:
742
- # dynamically load ZhipuAI
743
- try:
744
- from zhipuai import ZhipuAI
745
- except ImportError:
746
- raise ImportError("Please install zhipuai before initialize zhipuai backend.")
747
-
748
- if api_key:
749
- client = ZhipuAI(api_key=api_key)
750
- else:
751
- # please set ZHIPUAI_API_KEY in your environment
752
- # os.environ["ZHIPUAI_API_KEY"]
753
- client = ZhipuAI()
754
-
755
- messages = []
756
-
757
- if not system_prompt:
758
- system_prompt = "You are a helpful assistant. Note that sensitive words in the content should be replaced with ***"
759
-
760
- # Add system prompt if provided
761
- if system_prompt:
762
- messages.append({"role": "system", "content": system_prompt})
763
- messages.extend(history_messages)
764
- messages.append({"role": "user", "content": prompt})
765
-
766
- # Add debug logging
767
- logger.debug("===== Query Input to LLM =====")
768
- logger.debug(f"Query: {prompt}")
769
- logger.debug(f"System prompt: {system_prompt}")
770
-
771
- # Remove unsupported kwargs
772
- kwargs = {
773
- k: v for k, v in kwargs.items() if k not in ["hashing_kv", "keyword_extraction"]
774
- }
775
-
776
- response = client.chat.completions.create(model=model, messages=messages, **kwargs)
777
-
778
- return response.choices[0].message.content
779
-
780
-
781
- async def zhipu_complete(
782
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
783
- ):
784
- # Pop keyword_extraction from kwargs to avoid passing it to zhipu_complete_if_cache
785
- keyword_extraction = kwargs.pop("keyword_extraction", None)
786
-
787
- if keyword_extraction:
788
- # Add a system prompt to guide the model to return JSON format
789
- extraction_prompt = """You are a helpful assistant that extracts keywords from text.
790
- Please analyze the content and extract two types of keywords:
791
- 1. High-level keywords: Important concepts and main themes
792
- 2. Low-level keywords: Specific details and supporting elements
793
-
794
- Return your response in this exact JSON format:
795
- {
796
- "high_level_keywords": ["keyword1", "keyword2"],
797
- "low_level_keywords": ["keyword1", "keyword2", "keyword3"]
798
- }
799
-
800
- Only return the JSON, no other text."""
801
-
802
- # Combine with existing system prompt if any
803
- if system_prompt:
804
- system_prompt = f"{system_prompt}\n\n{extraction_prompt}"
805
- else:
806
- system_prompt = extraction_prompt
807
-
808
- try:
809
- response = await zhipu_complete_if_cache(
810
- prompt=prompt,
811
- system_prompt=system_prompt,
812
- history_messages=history_messages,
813
- **kwargs,
814
- )
815
-
816
- # Try to parse as JSON
817
- try:
818
- data = json.loads(response)
819
- return GPTKeywordExtractionFormat(
820
- high_level_keywords=data.get("high_level_keywords", []),
821
- low_level_keywords=data.get("low_level_keywords", []),
822
- )
823
- except json.JSONDecodeError:
824
- # If direct JSON parsing fails, try to extract JSON from text
825
- match = re.search(r"\{[\s\S]*\}", response)
826
- if match:
827
- try:
828
- data = json.loads(match.group())
829
- return GPTKeywordExtractionFormat(
830
- high_level_keywords=data.get("high_level_keywords", []),
831
- low_level_keywords=data.get("low_level_keywords", []),
832
- )
833
- except json.JSONDecodeError:
834
- pass
835
-
836
- # If all parsing fails, log warning and return empty format
837
- logger.warning(
838
- f"Failed to parse keyword extraction response: {response}"
839
- )
840
- return GPTKeywordExtractionFormat(
841
- high_level_keywords=[], low_level_keywords=[]
842
- )
843
- except Exception as e:
844
- logger.error(f"Error during keyword extraction: {str(e)}")
845
- return GPTKeywordExtractionFormat(
846
- high_level_keywords=[], low_level_keywords=[]
847
- )
848
- else:
849
- # For non-keyword-extraction, just return the raw response string
850
- return await zhipu_complete_if_cache(
851
- prompt=prompt,
852
- system_prompt=system_prompt,
853
- history_messages=history_messages,
854
- **kwargs,
855
- )
856
-
857
-
858
- @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
859
- @retry(
860
- stop=stop_after_attempt(3),
861
- wait=wait_exponential(multiplier=1, min=4, max=60),
862
- retry=retry_if_exception_type(
863
- (RateLimitError, APIConnectionError, APITimeoutError)
864
- ),
865
- )
866
- async def zhipu_embedding(
867
- texts: list[str], model: str = "embedding-3", api_key: str = None, **kwargs
868
- ) -> np.ndarray:
869
- # dynamically load ZhipuAI
870
- try:
871
- from zhipuai import ZhipuAI
872
- except ImportError:
873
- raise ImportError("Please install zhipuai before initialize zhipuai backend.")
874
- if api_key:
875
- client = ZhipuAI(api_key=api_key)
876
- else:
877
- # please set ZHIPUAI_API_KEY in your environment
878
- # os.environ["ZHIPUAI_API_KEY"]
879
- client = ZhipuAI()
880
-
881
- # Convert single text to list if needed
882
- if isinstance(texts, str):
883
- texts = [texts]
884
-
885
- embeddings = []
886
- for text in texts:
887
- try:
888
- response = client.embeddings.create(model=model, input=[text], **kwargs)
889
- embeddings.append(response.data[0].embedding)
890
- except Exception as e:
891
- raise Exception(f"Error calling ChatGLM Embedding API: {str(e)}")
892
-
893
- return np.array(embeddings)
894
-
895
-
896
- @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
897
- @retry(
898
- stop=stop_after_attempt(3),
899
- wait=wait_exponential(multiplier=1, min=4, max=60),
900
- retry=retry_if_exception_type(
901
- (RateLimitError, APIConnectionError, APITimeoutError)
902
- ),
903
- )
904
- async def openai_embedding(
905
- texts: list[str],
906
- model: str = "text-embedding-3-small",
907
- base_url: str = None,
908
- api_key: str = None,
909
- ) -> np.ndarray:
910
- if api_key:
911
- os.environ["OPENAI_API_KEY"] = api_key
912
-
913
- openai_async_client = (
914
- AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
915
- )
916
- response = await openai_async_client.embeddings.create(
917
- model=model, input=texts, encoding_format="float"
918
- )
919
- return np.array([dp.embedding for dp in response.data])
920
-
921
-
922
- async def fetch_data(url, headers, data):
923
- async with aiohttp.ClientSession() as session:
924
- async with session.post(url, headers=headers, json=data) as response:
925
- response_json = await response.json()
926
- data_list = response_json.get("data", [])
927
- return data_list
928
-
929
-
930
- async def jina_embedding(
931
- texts: list[str],
932
- dimensions: int = 1024,
933
- late_chunking: bool = False,
934
- base_url: str = None,
935
- api_key: str = None,
936
- ) -> np.ndarray:
937
- if api_key:
938
- os.environ["JINA_API_KEY"] = api_key
939
- url = "https://api.jina.ai/v1/embeddings" if not base_url else base_url
940
- headers = {
941
- "Content-Type": "application/json",
942
- "Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
943
- }
944
- data = {
945
- "model": "jina-embeddings-v3",
946
- "normalized": True,
947
- "embedding_type": "float",
948
- "dimensions": f"{dimensions}",
949
- "late_chunking": late_chunking,
950
- "input": texts,
951
- }
952
- data_list = await fetch_data(url, headers, data)
953
- return np.array([dp["embedding"] for dp in data_list])
954
-
955
-
956
- @wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=512)
957
- @retry(
958
- stop=stop_after_attempt(3),
959
- wait=wait_exponential(multiplier=1, min=4, max=60),
960
- retry=retry_if_exception_type(
961
- (RateLimitError, APIConnectionError, APITimeoutError)
962
- ),
963
- )
964
- async def nvidia_openai_embedding(
965
- texts: list[str],
966
- model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1",
967
- # refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
968
- base_url: str = "https://integrate.api.nvidia.com/v1",
969
- api_key: str = None,
970
- input_type: str = "passage", # query for retrieval, passage for embedding
971
- trunc: str = "NONE", # NONE or START or END
972
- encode: str = "float", # float or base64
973
- ) -> np.ndarray:
974
- if api_key:
975
- os.environ["OPENAI_API_KEY"] = api_key
976
-
977
- openai_async_client = (
978
- AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
979
- )
980
- response = await openai_async_client.embeddings.create(
981
- model=model,
982
- input=texts,
983
- encoding_format=encode,
984
- extra_body={"input_type": input_type, "truncate": trunc},
985
- )
986
- return np.array([dp.embedding for dp in response.data])
987
-
988
-
989
- @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8191)
990
- @retry(
991
- stop=stop_after_attempt(3),
992
- wait=wait_exponential(multiplier=1, min=4, max=10),
993
- retry=retry_if_exception_type(
994
- (RateLimitError, APIConnectionError, APITimeoutError)
995
- ),
996
- )
997
- async def azure_openai_embedding(
998
- texts: list[str],
999
- model: str = "text-embedding-3-small",
1000
- base_url: str = None,
1001
- api_key: str = None,
1002
- api_version: str = None,
1003
- ) -> np.ndarray:
1004
- if api_key:
1005
- os.environ["AZURE_OPENAI_API_KEY"] = api_key
1006
- if base_url:
1007
- os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
1008
- if api_version:
1009
- os.environ["AZURE_OPENAI_API_VERSION"] = api_version
1010
-
1011
- openai_async_client = AsyncAzureOpenAI(
1012
- azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
1013
- api_key=os.getenv("AZURE_OPENAI_API_KEY"),
1014
- api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
1015
- )
1016
-
1017
- response = await openai_async_client.embeddings.create(
1018
- model=model, input=texts, encoding_format="float"
1019
- )
1020
- return np.array([dp.embedding for dp in response.data])
1021
-
1022
-
1023
- @retry(
1024
- stop=stop_after_attempt(3),
1025
- wait=wait_exponential(multiplier=1, min=4, max=60),
1026
- retry=retry_if_exception_type(
1027
- (RateLimitError, APIConnectionError, APITimeoutError)
1028
- ),
1029
- )
1030
- async def siliconcloud_embedding(
1031
- texts: list[str],
1032
- model: str = "netease-youdao/bce-embedding-base_v1",
1033
- base_url: str = "https://api.siliconflow.cn/v1/embeddings",
1034
- max_token_size: int = 512,
1035
- api_key: str = None,
1036
- ) -> np.ndarray:
1037
- if api_key and not api_key.startswith("Bearer "):
1038
- api_key = "Bearer " + api_key
1039
-
1040
- headers = {"Authorization": api_key, "Content-Type": "application/json"}
1041
-
1042
- truncate_texts = [text[0:max_token_size] for text in texts]
1043
-
1044
- payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"}
1045
-
1046
- base64_strings = []
1047
- async with aiohttp.ClientSession() as session:
1048
- async with session.post(base_url, headers=headers, json=payload) as response:
1049
- content = await response.json()
1050
- if "code" in content:
1051
- raise ValueError(content)
1052
- base64_strings = [item["embedding"] for item in content["data"]]
1053
-
1054
- embeddings = []
1055
- for string in base64_strings:
1056
- decode_bytes = base64.b64decode(string)
1057
- n = len(decode_bytes) // 4
1058
- float_array = struct.unpack("<" + "f" * n, decode_bytes)
1059
- embeddings.append(float_array)
1060
- return np.array(embeddings)
1061
-
1062
-
1063
- # @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
1064
- # @retry(
1065
- # stop=stop_after_attempt(3),
1066
- # wait=wait_exponential(multiplier=1, min=4, max=10),
1067
- # retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
1068
- # )
1069
- async def bedrock_embedding(
1070
- texts: list[str],
1071
- model: str = "amazon.titan-embed-text-v2:0",
1072
- aws_access_key_id=None,
1073
- aws_secret_access_key=None,
1074
- aws_session_token=None,
1075
- ) -> np.ndarray:
1076
- os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
1077
- "AWS_ACCESS_KEY_ID", aws_access_key_id
1078
- )
1079
- os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
1080
- "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
1081
- )
1082
- os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
1083
- "AWS_SESSION_TOKEN", aws_session_token
1084
- )
1085
-
1086
- session = aioboto3.Session()
1087
- async with session.client("bedrock-runtime") as bedrock_async_client:
1088
- if (model_provider := model.split(".")[0]) == "amazon":
1089
- embed_texts = []
1090
- for text in texts:
1091
- if "v2" in model:
1092
- body = json.dumps(
1093
- {
1094
- "inputText": text,
1095
- # 'dimensions': embedding_dim,
1096
- "embeddingTypes": ["float"],
1097
- }
1098
- )
1099
- elif "v1" in model:
1100
- body = json.dumps({"inputText": text})
1101
- else:
1102
- raise ValueError(f"Model {model} is not supported!")
1103
-
1104
- response = await bedrock_async_client.invoke_model(
1105
- modelId=model,
1106
- body=body,
1107
- accept="application/json",
1108
- contentType="application/json",
1109
- )
1110
-
1111
- response_body = await response.get("body").json()
1112
-
1113
- embed_texts.append(response_body["embedding"])
1114
- elif model_provider == "cohere":
1115
- body = json.dumps(
1116
- {"texts": texts, "input_type": "search_document", "truncate": "NONE"}
1117
- )
1118
-
1119
- response = await bedrock_async_client.invoke_model(
1120
- model=model,
1121
- body=body,
1122
- accept="application/json",
1123
- contentType="application/json",
1124
- )
1125
-
1126
- response_body = json.loads(response.get("body").read())
1127
-
1128
- embed_texts = response_body["embeddings"]
1129
- else:
1130
- raise ValueError(f"Model provider '{model_provider}' is not supported!")
1131
-
1132
- return np.array(embed_texts)
1133
-
1134
-
1135
- async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
1136
- device = next(embed_model.parameters()).device
1137
- input_ids = tokenizer(
1138
- texts, return_tensors="pt", padding=True, truncation=True
1139
- ).input_ids.to(device)
1140
- with torch.no_grad():
1141
- outputs = embed_model(input_ids)
1142
- embeddings = outputs.last_hidden_state.mean(dim=1)
1143
- if embeddings.dtype == torch.bfloat16:
1144
- return embeddings.detach().to(torch.float32).cpu().numpy()
1145
- else:
1146
- return embeddings.detach().cpu().numpy()
1147
-
1148
-
1149
- async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
1150
- """
1151
- Deprecated in favor of `embed`.
1152
- """
1153
- embed_text = []
1154
- ollama_client = ollama.Client(**kwargs)
1155
- for text in texts:
1156
- data = ollama_client.embeddings(model=embed_model, prompt=text)
1157
- embed_text.append(data["embedding"])
1158
-
1159
- return embed_text
1160
-
1161
-
1162
- async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
1163
- api_key = kwargs.pop("api_key", None)
1164
- headers = (
1165
- {"Content-Type": "application/json", "Authorization": api_key}
1166
- if api_key
1167
- else {"Content-Type": "application/json"}
1168
- )
1169
- kwargs["headers"] = headers
1170
- ollama_client = ollama.Client(**kwargs)
1171
- data = ollama_client.embed(model=embed_model, input=texts)
1172
- return data["embeddings"]
1173
-
1174
-
1175
- async def lollms_embed(
1176
- texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
1177
- ) -> np.ndarray:
1178
- """
1179
- Generate embeddings for a list of texts using lollms server.
1180
-
1181
- Args:
1182
- texts: List of strings to embed
1183
- embed_model: Model name (not used directly as lollms uses configured vectorizer)
1184
- base_url: URL of the lollms server
1185
- **kwargs: Additional arguments passed to the request
1186
-
1187
- Returns:
1188
- np.ndarray: Array of embeddings
1189
- """
1190
- api_key = kwargs.pop("api_key", None)
1191
- headers = (
1192
- {"Content-Type": "application/json", "Authorization": api_key}
1193
- if api_key
1194
- else {"Content-Type": "application/json"}
1195
- )
1196
- async with aiohttp.ClientSession(headers=headers) as session:
1197
- embeddings = []
1198
- for text in texts:
1199
- request_data = {"text": text}
1200
-
1201
- async with session.post(
1202
- f"{base_url}/lollms_embed",
1203
- json=request_data,
1204
- ) as response:
1205
- result = await response.json()
1206
- embeddings.append(result["vector"])
1207
-
1208
- return np.array(embeddings)
1209
 
1210
 
1211
  class Model(BaseModel):
@@ -1293,6 +87,8 @@ if __name__ == "__main__":
1293
  import asyncio
1294
 
1295
  async def main():
 
 
1296
  result = await gpt_4o_mini_complete("How are you?")
1297
  print(result)
1298
 
 
1
+ from typing import List, Dict, Callable, Any
2
  from pydantic import BaseModel, Field
3
 
4
 
5
  class Model(BaseModel):
 
87
  import asyncio
88
 
89
  async def main():
90
+ from lightrag.llm.openai import gpt_4o_mini_complete
91
+
92
  result = await gpt_4o_mini_complete("How are you?")
93
  print(result)
94
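
With the provider code moved out of `llm.py`, the quickest smoke test is to call the relocated OpenAI helper directly, mirroring the `__main__` block above. A minimal sketch, assuming `lightrag` is installed and `OPENAI_API_KEY` is set:

```python
# Minimal sketch of the new import path after the llm.py split.
import asyncio

from lightrag.llm.openai import gpt_4o_mini_complete  # relocated from lightrag.llm


async def main():
    result = await gpt_4o_mini_complete("How are you?")
    print(result)


if __name__ == "__main__":
    asyncio.run(main())
```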
 
lightrag/llm/__init__.py ADDED
File without changes
lightrag/llm/azure_openai.py ADDED
@@ -0,0 +1,188 @@
1
+ """
2
+ Azure OpenAI LLM Interface Module
3
+ ==========================
4
+
5
+ This module provides interfaces for interacting with Azure OpenAI's language models,
6
+ including text generation and embedding capabilities.
7
+
8
+ Author: Lightrag team
9
+ Created: 2024-01-24
10
+ License: MIT License
11
+
12
+ Copyright (c) 2024 Lightrag
13
+
14
+ Permission is hereby granted, free of charge, to any person obtaining a copy
15
+ of this software and associated documentation files (the "Software"), to deal
16
+ in the Software without restriction, including without limitation the rights
17
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
+ copies of the Software, and to permit persons to whom the Software is
19
+ furnished to do so, subject to the following conditions:
20
+
21
+ Version: 1.0.0
22
+
23
+ Change Log:
24
+ - 1.0.0 (2024-01-24): Initial release
25
+ * Added async chat completion support
26
+ * Added embedding generation
27
+ * Added stream response capability
28
+
29
+ Dependencies:
30
+ - openai
31
+ - numpy
32
+ - pipmaster
33
+ - Python >= 3.10
34
+
35
+ Usage:
36
+ from lightrag.llm.azure_openai import azure_openai_complete, azure_openai_embed
37
+ """
38
+
39
+ __version__ = "1.0.0"
40
+ __author__ = "lightrag Team"
41
+ __status__ = "Production"
42
+
43
+
44
+ import os
45
+ import pipmaster as pm # Pipmaster for dynamic library install
46
+
47
+ # install specific modules
48
+ if not pm.is_installed("openai"):
49
+ pm.install("openai")
50
+ if not pm.is_installed("tenacity"):
51
+ pm.install("tenacity")
52
+
53
+ from openai import (
54
+ AsyncAzureOpenAI,
55
+ APIConnectionError,
56
+ RateLimitError,
57
+ APITimeoutError,
58
+ )
59
+ from tenacity import (
60
+ retry,
61
+ stop_after_attempt,
62
+ wait_exponential,
63
+ retry_if_exception_type,
64
+ )
65
+
66
+ from lightrag.utils import (
67
+ wrap_embedding_func_with_attrs,
68
+ locate_json_string_body_from_string,
69
+ safe_unicode_decode,
70
+ )
71
+
72
+ import numpy as np
73
+
74
+ @retry(
75
+ stop=stop_after_attempt(3),
76
+ wait=wait_exponential(multiplier=1, min=4, max=10),
77
+ retry=retry_if_exception_type(
78
+ (RateLimitError, APIConnectionError, APITimeoutError)
79
+ ),
80
+ )
81
+ async def azure_openai_complete_if_cache(
82
+ model,
83
+ prompt,
84
+ system_prompt=None,
85
+ history_messages=[],
86
+ base_url=None,
87
+ api_key=None,
88
+ api_version=None,
89
+ **kwargs,
90
+ ):
91
+ if api_key:
92
+ os.environ["AZURE_OPENAI_API_KEY"] = api_key
93
+ if base_url:
94
+ os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
95
+ if api_version:
96
+ os.environ["AZURE_OPENAI_API_VERSION"] = api_version
97
+
98
+ openai_async_client = AsyncAzureOpenAI(
99
+ azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
100
+ api_key=os.getenv("AZURE_OPENAI_API_KEY"),
101
+ api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
102
+ )
103
+ kwargs.pop("hashing_kv", None)
104
+ messages = []
105
+ if system_prompt:
106
+ messages.append({"role": "system", "content": system_prompt})
107
+ messages.extend(history_messages)
108
+ if prompt is not None:
109
+ messages.append({"role": "user", "content": prompt})
110
+
111
+ if "response_format" in kwargs:
112
+ response = await openai_async_client.beta.chat.completions.parse(
113
+ model=model, messages=messages, **kwargs
114
+ )
115
+ else:
116
+ response = await openai_async_client.chat.completions.create(
117
+ model=model, messages=messages, **kwargs
118
+ )
119
+
120
+ if hasattr(response, "__aiter__"):
121
+
122
+ async def inner():
123
+ async for chunk in response:
124
+ if len(chunk.choices) == 0:
125
+ continue
126
+ content = chunk.choices[0].delta.content
127
+ if content is None:
128
+ continue
129
+ if r"\u" in content:
130
+ content = safe_unicode_decode(content.encode("utf-8"))
131
+ yield content
132
+
133
+ return inner()
134
+ else:
135
+ content = response.choices[0].message.content
136
+ if r"\u" in content:
137
+ content = safe_unicode_decode(content.encode("utf-8"))
138
+ return content
139
+
140
+
141
+ async def azure_openai_complete(
142
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
143
+ ) -> str:
144
+ keyword_extraction = kwargs.pop("keyword_extraction", None)
145
+ result = await azure_openai_complete_if_cache(
146
+ os.getenv("LLM_MODEL", "gpt-4o-mini"),
147
+ prompt,
148
+ system_prompt=system_prompt,
149
+ history_messages=history_messages,
150
+ **kwargs,
151
+ )
152
+ if keyword_extraction: # TODO: use JSON API
153
+ return locate_json_string_body_from_string(result)
154
+ return result
155
+
156
+ @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8191)
157
+ @retry(
158
+ stop=stop_after_attempt(3),
159
+ wait=wait_exponential(multiplier=1, min=4, max=10),
160
+ retry=retry_if_exception_type(
161
+ (RateLimitError, APIConnectionError, APITimeoutError)
162
+ ),
163
+ )
164
+ async def azure_openai_embed(
165
+ texts: list[str],
166
+ model: str = "text-embedding-3-small",
167
+ base_url: str = None,
168
+ api_key: str = None,
169
+ api_version: str = None,
170
+ ) -> np.ndarray:
171
+ if api_key:
172
+ os.environ["AZURE_OPENAI_API_KEY"] = api_key
173
+ if base_url:
174
+ os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
175
+ if api_version:
176
+ os.environ["AZURE_OPENAI_API_VERSION"] = api_version
177
+
178
+ openai_async_client = AsyncAzureOpenAI(
179
+ azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
180
+ api_key=os.getenv("AZURE_OPENAI_API_KEY"),
181
+ api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
182
+ )
183
+
184
+ response = await openai_async_client.embeddings.create(
185
+ model=model, input=texts, encoding_format="float"
186
+ )
187
+ return np.array([dp.embedding for dp in response.data])
188
+
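
A hedged usage sketch for the new module; it assumes `AZURE_OPENAI_API_KEY`, `AZURE_OPENAI_ENDPOINT` and `AZURE_OPENAI_API_VERSION` are exported (or passed explicitly) and that `LLM_MODEL` names a valid deployment:

```python
# Sketch only: both helpers read Azure credentials from the environment.
import asyncio

from lightrag.llm.azure_openai import azure_openai_complete, azure_openai_embed


async def demo():
    answer = await azure_openai_complete("Say hello in one short sentence.")
    vectors = await azure_openai_embed(["hello", "world"])  # shape (2, 1536)
    print(answer, vectors.shape)


asyncio.run(demo())
```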
lightrag/llm/bedrock.py ADDED
@@ -0,0 +1,229 @@
1
+ """
2
+ Bedrock LLM Interface Module
3
+ ==========================
4
+
5
+ This module provides interfaces for interacting with Bedrock's language models,
6
+ including text generation and embedding capabilities.
7
+
8
+ Author: Lightrag team
9
+ Created: 2024-01-24
10
+ License: MIT License
11
+
12
+ Copyright (c) 2024 Lightrag
13
+
14
+ Permission is hereby granted, free of charge, to any person obtaining a copy
15
+ of this software and associated documentation files (the "Software"), to deal
16
+ in the Software without restriction, including without limitation the rights
17
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
+ copies of the Software, and to permit persons to whom the Software is
19
+ furnished to do so, subject to the following conditions:
20
+
21
+ Version: 1.0.0
22
+
23
+ Change Log:
24
+ - 1.0.0 (2024-01-24): Initial release
25
+ * Added async chat completion support
26
+ * Added embedding generation
27
+ * Added stream response capability
28
+
29
+ Dependencies:
30
+ - aioboto3, tenacity
31
+ - numpy
32
+ - pipmaster
33
+ - Python >= 3.10
34
+
35
+ Usage:
36
+ from lightrag.llm.bedrock import bedrock_complete, bedrock_embed
37
+ """
38
+
39
+ __version__ = "1.0.0"
40
+ __author__ = "lightrag Team"
41
+ __status__ = "Production"
42
+
43
+
44
+ import sys
45
+ import copy
46
+ import os
47
+ import json
48
+
49
+ import pipmaster as pm # Pipmaster for dynamic library install
50
+ if not pm.is_installed("aioboto3"):
51
+ pm.install("aioboto3")
52
+ if not pm.is_installed("tenacity"):
53
+ pm.install("tenacity")
54
+ import aioboto3
55
+ import numpy as np
56
+ from tenacity import (
57
+ retry,
58
+ stop_after_attempt,
59
+ wait_exponential,
60
+ retry_if_exception_type,
61
+ )
62
+
63
+ from lightrag.exceptions import (
64
+ APIConnectionError,
65
+ RateLimitError,
66
+ APITimeoutError,
67
+ )
68
+ from lightrag.utils import (
69
+ locate_json_string_body_from_string,
70
+ )
71
+
72
+ class BedrockError(Exception):
73
+ """Generic error for issues related to Amazon Bedrock"""
74
+
75
+
76
+ @retry(
77
+ stop=stop_after_attempt(5),
78
+ wait=wait_exponential(multiplier=1, max=60),
79
+ retry=retry_if_exception_type((BedrockError)),
80
+ )
81
+ async def bedrock_complete_if_cache(
82
+ model,
83
+ prompt,
84
+ system_prompt=None,
85
+ history_messages=[],
86
+ aws_access_key_id=None,
87
+ aws_secret_access_key=None,
88
+ aws_session_token=None,
89
+ **kwargs,
90
+ ) -> str:
91
+ os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
92
+ "AWS_ACCESS_KEY_ID", aws_access_key_id
93
+ )
94
+ os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
95
+ "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
96
+ )
97
+ os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
98
+ "AWS_SESSION_TOKEN", aws_session_token
99
+ )
100
+ kwargs.pop("hashing_kv", None)
101
+ # Fix message history format
102
+ messages = []
103
+ for history_message in history_messages:
104
+ message = copy.copy(history_message)
105
+ message["content"] = [{"text": message["content"]}]
106
+ messages.append(message)
107
+
108
+ # Add user prompt
109
+ messages.append({"role": "user", "content": [{"text": prompt}]})
110
+
111
+ # Initialize Converse API arguments
112
+ args = {"modelId": model, "messages": messages}
113
+
114
+ # Define system prompt
115
+ if system_prompt:
116
+ args["system"] = [{"text": system_prompt}]
117
+
118
+ # Map and set up inference parameters
119
+ inference_params_map = {
120
+ "max_tokens": "maxTokens",
121
+ "top_p": "topP",
122
+ "stop_sequences": "stopSequences",
123
+ }
124
+ if inference_params := list(
125
+ set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
126
+ ):
127
+ args["inferenceConfig"] = {}
128
+ for param in inference_params:
129
+ args["inferenceConfig"][inference_params_map.get(param, param)] = (
130
+ kwargs.pop(param)
131
+ )
132
+
133
+ # Call model via Converse API
134
+ session = aioboto3.Session()
135
+ async with session.client("bedrock-runtime") as bedrock_async_client:
136
+ try:
137
+ response = await bedrock_async_client.converse(**args, **kwargs)
138
+ except Exception as e:
139
+ raise BedrockError(e)
140
+
141
+ return response["output"]["message"]["content"][0]["text"]
142
+
143
+
144
+ async def bedrock_complete(
145
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
146
+ ) -> str:
147
+ keyword_extraction = kwargs.pop("keyword_extraction", None)
148
+ result = await bedrock_complete_if_cache(
149
+ "anthropic.claude-3-haiku-20240307-v1:0",
150
+ prompt,
151
+ system_prompt=system_prompt,
152
+ history_messages=history_messages,
153
+ **kwargs,
154
+ )
155
+ if keyword_extraction: # TODO: use JSON API
156
+ return locate_json_string_body_from_string(result)
157
+ return result
158
+
159
+
160
+ # @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
161
+ # @retry(
162
+ # stop=stop_after_attempt(3),
163
+ # wait=wait_exponential(multiplier=1, min=4, max=10),
164
+ # retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)), # TODO: fix exceptions
165
+ # )
166
+ async def bedrock_embed(
167
+ texts: list[str],
168
+ model: str = "amazon.titan-embed-text-v2:0",
169
+ aws_access_key_id=None,
170
+ aws_secret_access_key=None,
171
+ aws_session_token=None,
172
+ ) -> np.ndarray:
173
+ os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
174
+ "AWS_ACCESS_KEY_ID", aws_access_key_id
175
+ )
176
+ os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
177
+ "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
178
+ )
179
+ os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
180
+ "AWS_SESSION_TOKEN", aws_session_token
181
+ )
182
+
183
+ session = aioboto3.Session()
184
+ async with session.client("bedrock-runtime") as bedrock_async_client:
185
+ if (model_provider := model.split(".")[0]) == "amazon":
186
+ embed_texts = []
187
+ for text in texts:
188
+ if "v2" in model:
189
+ body = json.dumps(
190
+ {
191
+ "inputText": text,
192
+ # 'dimensions': embedding_dim,
193
+ "embeddingTypes": ["float"],
194
+ }
195
+ )
196
+ elif "v1" in model:
197
+ body = json.dumps({"inputText": text})
198
+ else:
199
+ raise ValueError(f"Model {model} is not supported!")
200
+
201
+ response = await bedrock_async_client.invoke_model(
202
+ modelId=model,
203
+ body=body,
204
+ accept="application/json",
205
+ contentType="application/json",
206
+ )
207
+
208
+ response_body = await response.get("body").json()
209
+
210
+ embed_texts.append(response_body["embedding"])
211
+ elif model_provider == "cohere":
212
+ body = json.dumps(
213
+ {"texts": texts, "input_type": "search_document", "truncate": "NONE"}
214
+ )
215
+
216
+ response = await bedrock_async_client.invoke_model(
217
+ model=model,
218
+ body=body,
219
+ accept="application/json",
220
+ contentType="application/json",
221
+ )
222
+
223
+ response_body = json.loads(response.get("body").read())
224
+
225
+ embed_texts = response_body["embeddings"]
226
+ else:
227
+ raise ValueError(f"Model provider '{model_provider}' is not supported!")
228
+
229
+ return np.array(embed_texts)
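
A usage sketch under the assumption that AWS credentials are already available to `aioboto3` (environment variables or the default credential chain) and that the default Claude 3 Haiku and Titan embedding models are enabled in the account:

```python
# Sketch: model IDs below are the module defaults; credentials come from the environment.
import asyncio

from lightrag.llm.bedrock import bedrock_complete, bedrock_embed


async def demo():
    answer = await bedrock_complete("Name one use case for a knowledge graph.")
    vectors = await bedrock_embed(["hello", "world"])
    print(answer, vectors.shape)


asyncio.run(demo())
```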
lightrag/llm/hf.py ADDED
@@ -0,0 +1,187 @@
1
+ """
2
+ Hugging Face LLM Interface Module
3
+ ==========================
4
+
5
+ This module provides interfaces for interacting with Hugging Face's language models,
6
+ including text generation and embedding capabilities.
7
+
8
+ Author: Lightrag team
9
+ Created: 2024-01-24
10
+ License: MIT License
11
+
12
+ Copyright (c) 2024 Lightrag
13
+
14
+ Permission is hereby granted, free of charge, to any person obtaining a copy
15
+ of this software and associated documentation files (the "Software"), to deal
16
+ in the Software without restriction, including without limitation the rights
17
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
+ copies of the Software, and to permit persons to whom the Software is
19
+ furnished to do so, subject to the following conditions:
20
+
21
+ Version: 1.0.0
22
+
23
+ Change Log:
24
+ - 1.0.0 (2024-01-24): Initial release
25
+ * Added async chat completion support
26
+ * Added embedding generation
27
+ * Added stream response capability
28
+
29
+ Dependencies:
30
+ - transformers
31
+ - numpy
32
+ - pipmaster
33
+ - Python >= 3.10
34
+
35
+ Usage:
36
+ from lightrag.llm.hf import hf_model_complete, hf_embed
37
+ """
38
+
39
+ __version__ = "1.0.0"
40
+ __author__ = "lightrag Team"
41
+ __status__ = "Production"
42
+
43
+ import copy
44
+ import os
45
+ import pipmaster as pm # Pipmaster for dynamic library install
46
+
47
+ # install specific modules
48
+ if not pm.is_installed("transformers"):
49
+ pm.install("transformers")
50
+ if not pm.is_installed("torch"):
51
+ pm.install("torch")
52
+ if not pm.is_installed("tenacity"):
53
+ pm.install("tenacity")
54
+
55
+ from transformers import AutoTokenizer, AutoModelForCausalLM
56
+ from functools import lru_cache
57
+ from tenacity import (
58
+ retry,
59
+ stop_after_attempt,
60
+ wait_exponential,
61
+ retry_if_exception_type,
62
+ )
63
+ from lightrag.exceptions import (
64
+ APIConnectionError,
65
+ RateLimitError,
66
+ APITimeoutError,
67
+ )
68
+ from lightrag.utils import (
69
+ locate_json_string_body_from_string,
70
+ )
71
+ import numpy as np  # needed for the np.ndarray annotation on hf_embed
+ import torch
72
+
73
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
74
+
75
+ @lru_cache(maxsize=1)
76
+ def initialize_hf_model(model_name):
77
+ hf_tokenizer = AutoTokenizer.from_pretrained(
78
+ model_name, device_map="auto", trust_remote_code=True
79
+ )
80
+ hf_model = AutoModelForCausalLM.from_pretrained(
81
+ model_name, device_map="auto", trust_remote_code=True
82
+ )
83
+ if hf_tokenizer.pad_token is None:
84
+ hf_tokenizer.pad_token = hf_tokenizer.eos_token
85
+
86
+ return hf_model, hf_tokenizer
87
+
88
+
89
+ @retry(
90
+ stop=stop_after_attempt(3),
91
+ wait=wait_exponential(multiplier=1, min=4, max=10),
92
+ retry=retry_if_exception_type(
93
+ (RateLimitError, APIConnectionError, APITimeoutError)
94
+ ),
95
+ )
96
+ async def hf_model_if_cache(
97
+ model,
98
+ prompt,
99
+ system_prompt=None,
100
+ history_messages=[],
101
+ **kwargs,
102
+ ) -> str:
103
+ model_name = model
104
+ hf_model, hf_tokenizer = initialize_hf_model(model_name)
105
+ messages = []
106
+ if system_prompt:
107
+ messages.append({"role": "system", "content": system_prompt})
108
+ messages.extend(history_messages)
109
+ messages.append({"role": "user", "content": prompt})
110
+ kwargs.pop("hashing_kv", None)
111
+ input_prompt = ""
112
+ try:
113
+ input_prompt = hf_tokenizer.apply_chat_template(
114
+ messages, tokenize=False, add_generation_prompt=True
115
+ )
116
+ except Exception:
117
+ try:
118
+ ori_message = copy.deepcopy(messages)
119
+ if messages[0]["role"] == "system":
120
+ messages[1]["content"] = (
121
+ "<system>"
122
+ + messages[0]["content"]
123
+ + "</system>\n"
124
+ + messages[1]["content"]
125
+ )
126
+ messages = messages[1:]
127
+ input_prompt = hf_tokenizer.apply_chat_template(
128
+ messages, tokenize=False, add_generation_prompt=True
129
+ )
130
+ except Exception:
131
+ len_message = len(ori_message)
132
+ for msgid in range(len_message):
133
+ input_prompt = (
134
+ input_prompt
135
+ + "<"
136
+ + ori_message[msgid]["role"]
137
+ + ">"
138
+ + ori_message[msgid]["content"]
139
+ + "</"
140
+ + ori_message[msgid]["role"]
141
+ + ">\n"
142
+ )
143
+
144
+ input_ids = hf_tokenizer(
145
+ input_prompt, return_tensors="pt", padding=True, truncation=True
146
+ ).to(hf_model.device)
147
+ inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
148
+ output = hf_model.generate(
149
+ **input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True
150
+ )
151
+ response_text = hf_tokenizer.decode(
152
+ output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
153
+ )
154
+
155
+ return response_text
156
+
157
+
158
+
159
+ async def hf_model_complete(
160
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
161
+ ) -> str:
162
+ keyword_extraction = kwargs.pop("keyword_extraction", None)
163
+ model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
164
+ result = await hf_model_if_cache(
165
+ model_name,
166
+ prompt,
167
+ system_prompt=system_prompt,
168
+ history_messages=history_messages,
169
+ **kwargs,
170
+ )
171
+ if keyword_extraction: # TODO: use JSON API
172
+ return locate_json_string_body_from_string(result)
173
+ return result
174
+
175
+
176
+ async def hf_embed(texts: list[str], tokenizer, embed_model) -> np.ndarray:
177
+ device = next(embed_model.parameters()).device
178
+ input_ids = tokenizer(
179
+ texts, return_tensors="pt", padding=True, truncation=True
180
+ ).input_ids.to(device)
181
+ with torch.no_grad():
182
+ outputs = embed_model(input_ids)
183
+ embeddings = outputs.last_hidden_state.mean(dim=1)
184
+ if embeddings.dtype == torch.bfloat16:
185
+ return embeddings.detach().to(torch.float32).cpu().numpy()
186
+ else:
187
+ return embeddings.detach().cpu().numpy()
lightrag/llm/jina.py ADDED
@@ -0,0 +1,104 @@
 
1
+ """
2
+ Jina Embedding Interface Module
3
+ ==========================
4
+
5
+ This module provides interfaces for interacting with jina system,
6
+ including embedding capabilities.
7
+
8
+ Author: Lightrag team
9
+ Created: 2024-01-24
10
+ License: MIT License
11
+
12
+ Copyright (c) 2024 Lightrag
13
+
14
+ Permission is hereby granted, free of charge, to any person obtaining a copy
15
+ of this software and associated documentation files (the "Software"), to deal
16
+ in the Software without restriction, including without limitation the rights
17
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
+ copies of the Software, and to permit persons to whom the Software is
19
+ furnished to do so, subject to the following conditions:
20
+
21
+ Version: 1.0.0
22
+
23
+ Change Log:
24
+ - 1.0.0 (2024-01-24): Initial release
25
+ * Added embedding generation
26
+
27
+ Dependencies:
28
+ - tenacity
29
+ - numpy
30
+ - pipmaster
31
+ - Python >= 3.10
32
+
33
+ Usage:
34
+ from lightrag.llm.jina import jina_embed
35
+ """
36
+
37
+ __version__ = "1.0.0"
38
+ __author__ = "lightrag Team"
39
+ __status__ = "Production"
40
+
41
+ import os
42
+ import pipmaster as pm # Pipmaster for dynamic library install
43
+
44
+ # install specific modules
45
+ if not pm.is_installed("aiohttp"):
46
+ pm.install("aiohttp")
47
+ if not pm.is_installed("tenacity"):
48
+ pm.install("tenacity")
49
+
50
+ from tenacity import (
51
+ retry,
52
+ stop_after_attempt,
53
+ wait_exponential,
54
+ retry_if_exception_type,
55
+ )
56
+
57
+ from lightrag.utils import (
58
+ wrap_embedding_func_with_attrs,
59
+ locate_json_string_body_from_string,
60
+ safe_unicode_decode,
61
+ logger,
62
+ )
63
+
64
+ from lightrag.types import GPTKeywordExtractionFormat
65
+ from functools import lru_cache
66
+
67
+ import numpy as np
68
+ from typing import Union
69
+ import aiohttp
70
+
71
+
72
+ async def fetch_data(url, headers, data):
73
+ async with aiohttp.ClientSession() as session:
74
+ async with session.post(url, headers=headers, json=data) as response:
75
+ response_json = await response.json()
76
+ data_list = response_json.get("data", [])
77
+ return data_list
78
+
79
+
80
+ async def jina_embed(
81
+ texts: list[str],
82
+ dimensions: int = 1024,
83
+ late_chunking: bool = False,
84
+ base_url: str = None,
85
+ api_key: str = None,
86
+ ) -> np.ndarray:
87
+ if api_key:
88
+ os.environ["JINA_API_KEY"] = api_key
89
+ url = "https://api.jina.ai/v1/embeddings" if not base_url else base_url
90
+ headers = {
91
+ "Content-Type": "application/json",
92
+ "Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
93
+ }
94
+ data = {
95
+ "model": "jina-embeddings-v3",
96
+ "normalized": True,
97
+ "embedding_type": "float",
98
+ "dimensions": f"{dimensions}",
99
+ "late_chunking": late_chunking,
100
+ "input": texts,
101
+ }
102
+ data_list = await fetch_data(url, headers, data)
103
+ return np.array([dp["embedding"] for dp in data_list])
104
+
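
A sketch assuming a valid Jina API key (passed in here as a placeholder, or already exported as `JINA_API_KEY`):

```python
# Sketch: jina_embed posts the texts to the hosted Jina embeddings endpoint.
import asyncio

from lightrag.llm.jina import jina_embed

vectors = asyncio.run(
    jina_embed(["hello", "world"], dimensions=1024, api_key="jina_xxx")  # placeholder key
)
print(vectors.shape)  # expected (2, 1024)
```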
lightrag/llm/lmdeploy.py ADDED
@@ -0,0 +1,190 @@
1
+ """
2
+ LMDeploy LLM Interface Module
3
+ ==========================
4
+
5
+ This module provides interfaces for interacting with LMDeploy's language models,
6
+ including text generation and embedding capabilities.
7
+
8
+ Author: Lightrag team
9
+ Created: 2024-01-24
10
+ License: MIT License
11
+
12
+ Copyright (c) 2024 Lightrag
13
+
14
+ Permission is hereby granted, free of charge, to any person obtaining a copy
15
+ of this software and associated documentation files (the "Software"), to deal
16
+ in the Software without restriction, including without limitation the rights
17
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
+ copies of the Software, and to permit persons to whom the Software is
19
+ furnished to do so, subject to the following conditions:
20
+
21
+ Version: 1.0.0
22
+
23
+ Change Log:
24
+ - 1.0.0 (2024-01-24): Initial release
25
+ * Added async chat completion support
26
+ * Added embedding generation
27
+ * Added stream response capability
28
+
29
+ Dependencies:
30
+ - tenacity
31
+ - numpy
32
+ - pipmaster
33
+ - Python >= 3.10
34
+
35
+ Usage:
36
+ from lightrag.llm.lmdeploy import lmdeploy_model_if_cache
37
+ """
38
+
39
+ __version__ = "1.0.0"
40
+ __author__ = "lightrag Team"
41
+ __status__ = "Production"
42
+
43
+ import pipmaster as pm # Pipmaster for dynamic library install
44
+
45
+ # install specific modules
46
+ if not pm.is_installed("lmdeploy"):
47
+ pm.install("lmdeploy[all]")
48
+ if not pm.is_installed("tenacity"):
49
+ pm.install("tenacity")
50
+
51
+ from lightrag.exceptions import (
52
+ APIConnectionError,
53
+ RateLimitError,
54
+ APITimeoutError,
55
+ )
56
+ from tenacity import (
57
+ retry,
58
+ stop_after_attempt,
59
+ wait_exponential,
60
+ retry_if_exception_type,
61
+ )
62
+
63
+
64
+ from functools import lru_cache
65
+
66
+ @lru_cache(maxsize=1)
67
+ def initialize_lmdeploy_pipeline(
68
+ model,
69
+ tp=1,
70
+ chat_template=None,
71
+ log_level="WARNING",
72
+ model_format="hf",
73
+ quant_policy=0,
74
+ ):
75
+ from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig
76
+
77
+ lmdeploy_pipe = pipeline(
78
+ model_path=model,
79
+ backend_config=TurbomindEngineConfig(
80
+ tp=tp, model_format=model_format, quant_policy=quant_policy
81
+ ),
82
+ chat_template_config=(
83
+ ChatTemplateConfig(model_name=chat_template) if chat_template else None
84
+ ),
85
+ log_level="WARNING",
86
+ )
87
+ return lmdeploy_pipe
88
+
89
+
90
+ @retry(
91
+ stop=stop_after_attempt(3),
92
+ wait=wait_exponential(multiplier=1, min=4, max=10),
93
+ retry=retry_if_exception_type(
94
+ (RateLimitError, APIConnectionError, APITimeoutError)
95
+ ),
96
+ )
97
+ async def lmdeploy_model_if_cache(
98
+ model,
99
+ prompt,
100
+ system_prompt=None,
101
+ history_messages=[],
102
+ chat_template=None,
103
+ model_format="hf",
104
+ quant_policy=0,
105
+ **kwargs,
106
+ ) -> str:
107
+ """
108
+ Args:
109
+ model (str): The path to the model.
110
+ It could be one of the following options:
111
+ - i) A local directory path of a turbomind model which is
112
+ converted by `lmdeploy convert` command or download
113
+ from ii) and iii).
114
+ - ii) The model_id of a lmdeploy-quantized model hosted
115
+ inside a model repo on huggingface.co, such as
116
+ "InternLM/internlm-chat-20b-4bit",
117
+ "lmdeploy/llama2-chat-70b-4bit", etc.
118
+ - iii) The model_id of a model hosted inside a model repo
119
+ on huggingface.co, such as "internlm/internlm-chat-7b",
120
+ "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
121
+ and so on.
122
+ chat_template (str): needed when model is a pytorch model on
123
+ huggingface.co, such as "internlm-chat-7b",
124
+ "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on,
125
+ and when the model name of local path did not match the original model name in HF.
126
+ tp (int): tensor parallel
127
+ prompt (Union[str, List[str]]): input texts to be completed.
128
+ do_preprocess (bool): whether pre-process the messages. Default to
129
+ True, which means chat_template will be applied.
130
+ skip_special_tokens (bool): Whether or not to remove special tokens
131
+ in the decoding. Default to be True.
132
+ do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise.
133
+ Default to be False, which means greedy decoding will be applied.
134
+ """
135
+ try:
136
+ import lmdeploy
137
+ from lmdeploy import version_info, GenerationConfig
138
+ except Exception:
139
+ raise ImportError("Please install lmdeploy before initializing the lmdeploy backend.")
140
+ kwargs.pop("hashing_kv", None)
141
+ kwargs.pop("response_format", None)
142
+ max_new_tokens = kwargs.pop("max_tokens", 512)
143
+ tp = kwargs.pop("tp", 1)
144
+ skip_special_tokens = kwargs.pop("skip_special_tokens", True)
145
+ do_preprocess = kwargs.pop("do_preprocess", True)
146
+ do_sample = kwargs.pop("do_sample", False)
147
+ gen_params = kwargs
148
+
149
+ version = version_info
150
+ if do_sample is not None and version < (0, 6, 0):
151
+ raise RuntimeError(
152
+ "`do_sample` parameter is not supported by lmdeploy until "
153
+ f"v0.6.0, but currently using lmdeploy {lmdeploy.__version__}"
154
+ )
155
+ else:
156
+ do_sample = True
157
+ gen_params.update(do_sample=do_sample)
158
+
159
+ lmdeploy_pipe = initialize_lmdeploy_pipeline(
160
+ model=model,
161
+ tp=tp,
162
+ chat_template=chat_template,
163
+ model_format=model_format,
164
+ quant_policy=quant_policy,
165
+ log_level="WARNING",
166
+ )
167
+
168
+ messages = []
169
+ if system_prompt:
170
+ messages.append({"role": "system", "content": system_prompt})
171
+
172
+ messages.extend(history_messages)
173
+ messages.append({"role": "user", "content": prompt})
174
+
175
+ gen_config = GenerationConfig(
176
+ skip_special_tokens=skip_special_tokens,
177
+ max_new_tokens=max_new_tokens,
178
+ **gen_params,
179
+ )
180
+
181
+ response = ""
182
+ async for res in lmdeploy_pipe.generate(
183
+ messages,
184
+ gen_config=gen_config,
185
+ do_preprocess=do_preprocess,
186
+ stream_response=False,
187
+ session_id=1,
188
+ ):
189
+ response += res.response
190
+ return response
lightrag/llm/lollms.py ADDED
@@ -0,0 +1,222 @@
 
1
+ """
2
+ LoLLMs (Lord of Large Language and Multimodal Systems) Interface Module
3
+ =====================================================
4
+
5
+ This module provides the official interface for interacting with LoLLMs (Lord of Large Language and multimodal Systems),
6
+ a unified framework for AI model interaction and deployment.
7
+
8
+ LoLLMs is designed as a "one tool to rule them all" solution, providing seamless integration
9
+ with various AI models while maintaining high performance and user-friendly interfaces.
10
+
11
+ Author: ParisNeo
12
+ Created: 2024-01-24
13
+ License: Apache 2.0
14
+
15
+ Copyright (c) 2024 ParisNeo
16
+
17
+ Licensed under the Apache License, Version 2.0 (the "License");
18
+ you may not use this file except in compliance with the License.
19
+ You may obtain a copy of the License at
20
+
21
+ http://www.apache.org/licenses/LICENSE-2.0
22
+
23
+ Unless required by applicable law or agreed to in writing, software
24
+ distributed under the License is distributed on an "AS IS" BASIS,
25
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26
+ See the License for the specific language governing permissions and
27
+ limitations under the License.
28
+
29
+ Version: 2.0.0
30
+
31
+ Change Log:
32
+ - 2.0.0 (2024-01-24):
33
+ * Added async support for model inference
34
+ * Implemented streaming capabilities
35
+ * Added embedding generation functionality
36
+ * Enhanced parameter handling
37
+ * Improved error handling and timeout management
38
+
39
+ Dependencies:
40
+ - aiohttp
41
+ - numpy
42
+ - Python >= 3.10
43
+
44
+ Features:
45
+ - Async text generation with streaming support
46
+ - Embedding generation
47
+ - Configurable model parameters
48
+ - System prompt and chat history support
49
+ - Timeout handling
50
+ - API key authentication
51
+
52
+ Usage:
53
+ from lightrag.llm.lollms import lollms_model_complete, lollms_embed
54
+
55
+ Project Repository: https://github.com/ParisNeo/lollms
56
+ Documentation: https://github.com/ParisNeo/lollms/docs
57
+ """
58
+
59
+ __version__ = "1.0.0"
60
+ __author__ = "ParisNeo"
61
+ __status__ = "Production"
62
+ __project_url__ = "https://github.com/ParisNeo/lollms"
63
+ __doc_url__ = "https://github.com/ParisNeo/lollms/docs"
64
+ import sys
65
+ if sys.version_info < (3, 9):
66
+ from typing import AsyncIterator
67
+ else:
68
+ from collections.abc import AsyncIterator
69
+ import pipmaster as pm # Pipmaster for dynamic library install
70
+ if not pm.is_installed("aiohttp"):
71
+ pm.install("aiohttp")
72
+ if not pm.is_installed("tenacity"):
73
+ pm.install("tenacity")
74
+
75
+ import aiohttp
76
+ from tenacity import (
77
+ retry,
78
+ stop_after_attempt,
79
+ wait_exponential,
80
+ retry_if_exception_type,
81
+ )
82
+
83
+ from lightrag.exceptions import (
84
+ APIConnectionError,
85
+ RateLimitError,
86
+ APITimeoutError,
87
+ )
88
+
89
+ from typing import Union, List
90
+ import numpy as np
91
+
92
+ @retry(
93
+ stop=stop_after_attempt(3),
94
+ wait=wait_exponential(multiplier=1, min=4, max=10),
95
+ retry=retry_if_exception_type(
96
+ (RateLimitError, APIConnectionError, APITimeoutError)
97
+ ),
98
+ )
99
+ async def lollms_model_if_cache(
100
+ model,
101
+ prompt,
102
+ system_prompt=None,
103
+ history_messages=[],
104
+ base_url="http://localhost:9600",
105
+ **kwargs,
106
+ ) -> Union[str, AsyncIterator[str]]:
107
+ """Client implementation for lollms generation."""
108
+
109
+ stream = True if kwargs.get("stream") else False
110
+ api_key = kwargs.pop("api_key", None)
111
+ headers = (
112
+ {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
113
+ if api_key
114
+ else {"Content-Type": "application/json"}
115
+ )
116
+
117
+ # Extract lollms specific parameters
118
+ request_data = {
119
+ "prompt": prompt,
120
+ "model_name": model,
121
+ "personality": kwargs.get("personality", -1),
122
+ "n_predict": kwargs.get("n_predict", None),
123
+ "stream": stream,
124
+ "temperature": kwargs.get("temperature", 0.1),
125
+ "top_k": kwargs.get("top_k", 50),
126
+ "top_p": kwargs.get("top_p", 0.95),
127
+ "repeat_penalty": kwargs.get("repeat_penalty", 0.8),
128
+ "repeat_last_n": kwargs.get("repeat_last_n", 40),
129
+ "seed": kwargs.get("seed", None),
130
+ "n_threads": kwargs.get("n_threads", 8),
131
+ }
132
+
133
+ # Prepare the full prompt including history
134
+ full_prompt = ""
135
+ if system_prompt:
136
+ full_prompt += f"{system_prompt}\n"
137
+ for msg in history_messages:
138
+ full_prompt += f"{msg['role']}: {msg['content']}\n"
139
+ full_prompt += prompt
140
+
141
+ request_data["prompt"] = full_prompt
142
+ timeout = aiohttp.ClientTimeout(total=kwargs.get("timeout", None))
143
+
144
+ async with aiohttp.ClientSession(timeout=timeout, headers=headers) as session:
145
+ if stream:
146
+
147
+ async def inner():
148
+ async with session.post(
149
+ f"{base_url}/lollms_generate", json=request_data
150
+ ) as response:
151
+ async for line in response.content:
152
+ yield line.decode().strip()
153
+
154
+ return inner()
155
+ else:
156
+ async with session.post(
157
+ f"{base_url}/lollms_generate", json=request_data
158
+ ) as response:
159
+ return await response.text()
160
+
161
+
162
+ async def lollms_model_complete(
163
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
164
+ ) -> Union[str, AsyncIterator[str]]:
165
+ """Complete function for lollms model generation."""
166
+
167
+ # Extract and remove keyword_extraction from kwargs if present
168
+ keyword_extraction = kwargs.pop("keyword_extraction", None)
169
+
170
+ # Get model name from config
171
+ model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
172
+
173
+ # If keyword extraction is needed, we might need to modify the prompt
174
+ # or add specific parameters for JSON output (if lollms supports it)
175
+ if keyword_extraction:
176
+ # Note: You might need to adjust this based on how lollms handles structured output
177
+ pass
178
+
179
+ return await lollms_model_if_cache(
180
+ model_name,
181
+ prompt,
182
+ system_prompt=system_prompt,
183
+ history_messages=history_messages,
184
+ **kwargs,
185
+ )
186
+
187
+
188
+
189
+ async def lollms_embed(
190
+ texts: List[str], embed_model=None, base_url="http://localhost:9600", **kwargs
191
+ ) -> np.ndarray:
192
+ """
193
+ Generate embeddings for a list of texts using lollms server.
194
+
195
+ Args:
196
+ texts: List of strings to embed
197
+ embed_model: Model name (not used directly as lollms uses configured vectorizer)
198
+ base_url: URL of the lollms server
199
+ **kwargs: Additional arguments passed to the request
200
+
201
+ Returns:
202
+ np.ndarray: Array of embeddings
203
+ """
204
+ api_key = kwargs.pop("api_key", None)
205
+ headers = (
206
+ {"Content-Type": "application/json", "Authorization": api_key}
207
+ if api_key
208
+ else {"Content-Type": "application/json"}
209
+ )
210
+ async with aiohttp.ClientSession(headers=headers) as session:
211
+ embeddings = []
212
+ for text in texts:
213
+ request_data = {"text": text}
214
+
215
+ async with session.post(
216
+ f"{base_url}/lollms_embed",
217
+ json=request_data,
218
+ ) as response:
219
+ result = await response.json()
220
+ embeddings.append(result["vector"])
221
+
222
+ return np.array(embeddings)
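The new module exposes `lollms_model_complete` for generation and `lollms_embed` for embeddings. A minimal wiring sketch, assuming a local lollms server on its default port 9600 and the LightRAG constructor arguments used in the README examples (the model name and embedding dimension are placeholders that must match the lollms configuration):

```python
from lightrag import LightRAG
from lightrag.llm.lollms import lollms_model_complete, lollms_embed
from lightrag.utils import EmbeddingFunc

rag = LightRAG(
    working_dir="./lollms_rag",
    llm_model_name="mistral-nemo:latest",  # assumption: the model currently active in lollms
    llm_model_func=lollms_model_complete,
    embedding_func=EmbeddingFunc(
        embedding_dim=1024,  # assumption: must match the vectorizer configured in lollms
        max_token_size=8192,
        func=lambda texts: lollms_embed(texts, base_url="http://localhost:9600"),
    ),
)
```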
lightrag/llm/nvidia_openai.py ADDED
@@ -0,0 +1,112 @@
1
+ """
2
+ NVIDIA OpenAI-Compatible Interface Module
3
+ =========================================
4
+
5
+ This module provides interfaces for NVIDIA's OpenAI-compatible API,
6
+ including embedding capabilities.
7
+
8
+ Author: Lightrag team
9
+ Created: 2024-01-24
10
+ License: MIT License
11
+
12
+ Copyright (c) 2024 Lightrag
13
+
14
+ Permission is hereby granted, free of charge, to any person obtaining a copy
15
+ of this software and associated documentation files (the "Software"), to deal
16
+ in the Software without restriction, including without limitation the rights
17
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
+ copies of the Software, and to permit persons to whom the Software is
19
+ furnished to do so, subject to the following conditions:
20
+
21
+ Version: 1.0.0
22
+
23
+ Change Log:
24
+ - 1.0.0 (2024-01-24): Initial release
25
+ * Added async chat completion support
26
+ * Added embedding generation
27
+ * Added stream response capability
28
+
29
+ Dependencies:
30
+ - openai
31
+ - numpy
32
+ - pipmaster
33
+ - Python >= 3.10
34
+
35
+ Usage:
36
+ from lightrag.llm.nvidia_openai import nvidia_openai_embed
37
+ """
38
+
39
+ __version__ = "1.0.0"
40
+ __author__ = "lightrag Team"
41
+ __status__ = "Production"
42
+
43
+
44
+
45
+ import sys
46
+ import os
47
+
48
+ if sys.version_info < (3, 9):
49
+ from typing import AsyncIterator
50
+ else:
51
+ from collections.abc import AsyncIterator
52
+ import pipmaster as pm # Pipmaster for dynamic library install
53
+
54
+ # install specific modules
55
+ if not pm.is_installed("openai"):
56
+ pm.install("openai")
57
+
58
+ from openai import (
59
+ AsyncOpenAI,
60
+ APIConnectionError,
61
+ RateLimitError,
62
+ APITimeoutError,
63
+ )
64
+ from tenacity import (
65
+ retry,
66
+ stop_after_attempt,
67
+ wait_exponential,
68
+ retry_if_exception_type,
69
+ )
70
+
71
+ from lightrag.utils import (
72
+ wrap_embedding_func_with_attrs,
73
+ locate_json_string_body_from_string,
74
+ safe_unicode_decode,
75
+ logger,
76
+ )
77
+
78
+ from lightrag.types import GPTKeywordExtractionFormat
79
+
80
+ import numpy as np
81
+
82
+ @wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=512)
83
+ @retry(
84
+ stop=stop_after_attempt(3),
85
+ wait=wait_exponential(multiplier=1, min=4, max=60),
86
+ retry=retry_if_exception_type(
87
+ (RateLimitError, APIConnectionError, APITimeoutError)
88
+ ),
89
+ )
90
+ async def nvidia_openai_embed(
91
+ texts: list[str],
92
+ model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1",
93
+ # refer to https://build.nvidia.com/nim?filters=usecase%3Ausecase_text_to_embedding
94
+ base_url: str = "https://integrate.api.nvidia.com/v1",
95
+ api_key: str = None,
96
+ input_type: str = "passage", # query for retrieval, passage for embedding
97
+ trunc: str = "NONE", # NONE or START or END
98
+ encode: str = "float", # float or base64
99
+ ) -> np.ndarray:
100
+ if api_key:
101
+ os.environ["OPENAI_API_KEY"] = api_key
102
+
103
+ openai_async_client = (
104
+ AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
105
+ )
106
+ response = await openai_async_client.embeddings.create(
107
+ model=model,
108
+ input=texts,
109
+ encoding_format=encode,
110
+ extra_body={"input_type": input_type, "truncate": trunc},
111
+ )
112
+ return np.array([dp.embedding for dp in response.data])
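`nvidia_openai_embed` can be exercised on its own; a small sketch with a placeholder API key (`input_type` should be `"query"` for retrieval queries and `"passage"` when indexing documents, as noted in the signature):

```python
import asyncio
from lightrag.llm.nvidia_openai import nvidia_openai_embed

async def main():
    vectors = await nvidia_openai_embed(
        ["What is LightRAG?"],
        api_key="nvapi-...",  # placeholder key
        input_type="query",
    )
    print(vectors.shape)  # (1, 2048) per the module's declared embedding_dim

asyncio.run(main())
```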
lightrag/llm/ollama.py ADDED
@@ -0,0 +1,155 @@
1
+ """
2
+ Ollama LLM Interface Module
3
+ ==========================
4
+
5
+ This module provides interfaces for interacting with Ollama's language models,
6
+ including text generation and embedding capabilities.
7
+
8
+ Author: Lightrag team
9
+ Created: 2024-01-24
10
+ License: MIT License
11
+
12
+ Copyright (c) 2024 Lightrag
13
+
14
+ Permission is hereby granted, free of charge, to any person obtaining a copy
15
+ of this software and associated documentation files (the "Software"), to deal
16
+ in the Software without restriction, including without limitation the rights
17
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
+ copies of the Software, and to permit persons to whom the Software is
19
+ furnished to do so, subject to the following conditions:
20
+
21
+ Version: 1.0.0
22
+
23
+ Change Log:
24
+ - 1.0.0 (2024-01-24): Initial release
25
+ * Added async chat completion support
26
+ * Added embedding generation
27
+ * Added stream response capability
28
+
29
+ Dependencies:
30
+ - ollama
31
+ - numpy
32
+ - pipmaster
33
+ - Python >= 3.10
34
+
35
+ Usage:
36
+ from lightrag.llm.ollama import ollama_model_complete, ollama_embed
37
+ """
38
+
39
+ __version__ = "1.0.0"
40
+ __author__ = "lightrag Team"
41
+ __status__ = "Production"
42
+
43
+ import sys
44
+ if sys.version_info < (3, 9):
45
+ from typing import AsyncIterator
46
+ else:
47
+ from collections.abc import AsyncIterator
48
+ import pipmaster as pm # Pipmaster for dynamic library install
49
+
50
+ # install specific modules
51
+ if not pm.is_installed("ollama"):
52
+ pm.install("ollama")
53
+ if not pm.is_installed("tenacity"):
54
+ pm.install("tenacity")
55
+
56
+ import ollama
57
+ from tenacity import (
58
+ retry,
59
+ stop_after_attempt,
60
+ wait_exponential,
61
+ retry_if_exception_type,
62
+ )
63
+ from lightrag.exceptions import (
64
+ APIConnectionError,
65
+ RateLimitError,
66
+ APITimeoutError,
67
+ )
68
+ import numpy as np
69
+ from typing import Union
70
+
71
+
72
+ @retry(
73
+ stop=stop_after_attempt(3),
74
+ wait=wait_exponential(multiplier=1, min=4, max=10),
75
+ retry=retry_if_exception_type(
76
+ (RateLimitError, APIConnectionError, APITimeoutError)
77
+ ),
78
+ )
79
+ async def ollama_model_if_cache(
80
+ model,
81
+ prompt,
82
+ system_prompt=None,
83
+ history_messages=[],
84
+ **kwargs,
85
+ ) -> Union[str, AsyncIterator[str]]:
86
+ stream = True if kwargs.get("stream") else False
87
+ kwargs.pop("max_tokens", None)
88
+ # kwargs.pop("response_format", None) # allow json
89
+ host = kwargs.pop("host", None)
90
+ timeout = kwargs.pop("timeout", None)
91
+ kwargs.pop("hashing_kv", None)
92
+ api_key = kwargs.pop("api_key", None)
93
+ headers = (
94
+ {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
95
+ if api_key
96
+ else {"Content-Type": "application/json"}
97
+ )
98
+ ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
99
+ messages = []
100
+ if system_prompt:
101
+ messages.append({"role": "system", "content": system_prompt})
102
+ messages.extend(history_messages)
103
+ messages.append({"role": "user", "content": prompt})
104
+
105
+ response = await ollama_client.chat(model=model, messages=messages, **kwargs)
106
+ if stream:
107
+ """cannot cache stream response"""
108
+
109
+ async def inner():
110
+ async for chunk in response:
111
+ yield chunk["message"]["content"]
112
+
113
+ return inner()
114
+ else:
115
+ return response["message"]["content"]
116
+
117
+ async def ollama_model_complete(
118
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
119
+ ) -> Union[str, AsyncIterator[str]]:
120
+ keyword_extraction = kwargs.pop("keyword_extraction", None)
121
+ if keyword_extraction:
122
+ kwargs["format"] = "json"
123
+ model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
124
+ return await ollama_model_if_cache(
125
+ model_name,
126
+ prompt,
127
+ system_prompt=system_prompt,
128
+ history_messages=history_messages,
129
+ **kwargs,
130
+ )
131
+
132
+ async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
133
+ """
134
+ Deprecated in favor of `embed`.
135
+ """
136
+ embed_text = []
137
+ ollama_client = ollama.Client(**kwargs)
138
+ for text in texts:
139
+ data = ollama_client.embeddings(model=embed_model, prompt=text)
140
+ embed_text.append(data["embedding"])
141
+
142
+ return embed_text
143
+
144
+
145
+ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
146
+ api_key = kwargs.pop("api_key", None)
147
+ headers = (
148
+ {"Content-Type": "application/json", "Authorization": api_key}
149
+ if api_key
150
+ else {"Content-Type": "application/json"}
151
+ )
152
+ kwargs["headers"] = headers
153
+ ollama_client = ollama.Client(**kwargs)
154
+ data = ollama_client.embed(model=embed_model, input=texts)
155
+ return data["embeddings"]
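`ollama_model_complete` and `ollama_embed` are the functions LightRAG now imports from `lightrag.llm.ollama`. A minimal wiring sketch, assuming a local Ollama instance with a chat model and `nomic-embed-text` already pulled (the model names and the 768-dim size are assumptions that must match the pulled models):

```python
from lightrag import LightRAG
from lightrag.llm.ollama import ollama_model_complete, ollama_embed
from lightrag.utils import EmbeddingFunc

rag = LightRAG(
    working_dir="./ollama_rag",
    llm_model_name="qwen2.5:7b",  # assumption: any chat model pulled into Ollama
    llm_model_func=ollama_model_complete,
    embedding_func=EmbeddingFunc(
        embedding_dim=768,  # assumption: nomic-embed-text produces 768-dim vectors
        max_token_size=8192,
        func=lambda texts: ollama_embed(texts, embed_model="nomic-embed-text"),
    ),
)
```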
lightrag/llm/openai.py ADDED
@@ -0,0 +1,232 @@
1
+ """
2
+ OpenAI LLM Interface Module
3
+ ==========================
4
+
5
+ This module provides interfaces for interacting with OpenAI's language models,
6
+ including text generation and embedding capabilities.
7
+
8
+ Author: Lightrag team
9
+ Created: 2024-01-24
10
+ License: MIT License
11
+
12
+ Copyright (c) 2024 Lightrag
13
+
14
+ Permission is hereby granted, free of charge, to any person obtaining a copy
15
+ of this software and associated documentation files (the "Software"), to deal
16
+ in the Software without restriction, including without limitation the rights
17
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
+ copies of the Software, and to permit persons to whom the Software is
19
+ furnished to do so, subject to the following conditions:
20
+
21
+ Version: 1.0.0
22
+
23
+ Change Log:
24
+ - 1.0.0 (2024-01-24): Initial release
25
+ * Added async chat completion support
26
+ * Added embedding generation
27
+ * Added stream response capability
28
+
29
+ Dependencies:
30
+ - openai
31
+ - numpy
32
+ - pipmaster
33
+ - Python >= 3.10
34
+
35
+ Usage:
36
+ from llm_interfaces.openai import openai_model_complete, openai_embed
37
+ """
38
+
39
+ __version__ = "1.0.0"
40
+ __author__ = "lightrag Team"
41
+ __status__ = "Production"
42
+
43
+
44
+
45
+ import sys
46
+ import os
47
+
48
+ if sys.version_info < (3, 9):
49
+ from typing import AsyncIterator
50
+ else:
51
+ from collections.abc import AsyncIterator
52
+ import pipmaster as pm # Pipmaster for dynamic library install
53
+
54
+ # install specific modules
55
+ if not pm.is_installed("openai"):
56
+ pm.install("openai")
57
+
58
+ from openai import (
59
+ AsyncOpenAI,
60
+ APIConnectionError,
61
+ RateLimitError,
62
+ APITimeoutError,
63
+ )
64
+ from tenacity import (
65
+ retry,
66
+ stop_after_attempt,
67
+ wait_exponential,
68
+ retry_if_exception_type,
69
+ )
70
+ from lightrag.utils import (
71
+ wrap_embedding_func_with_attrs,
72
+ locate_json_string_body_from_string,
73
+ safe_unicode_decode,
74
+ logger,
75
+ )
76
+ from lightrag.types import GPTKeywordExtractionFormat
77
+
78
+ import numpy as np
79
+ from typing import Union
80
+
81
+ @retry(
82
+ stop=stop_after_attempt(3),
83
+ wait=wait_exponential(multiplier=1, min=4, max=10),
84
+ retry=retry_if_exception_type(
85
+ (RateLimitError, APIConnectionError, APITimeoutError)
86
+ ),
87
+ )
88
+ async def openai_complete_if_cache(
89
+ model,
90
+ prompt,
91
+ system_prompt=None,
92
+ history_messages=[],
93
+ base_url=None,
94
+ api_key=None,
95
+ **kwargs,
96
+ ) -> str:
97
+ if api_key:
98
+ os.environ["OPENAI_API_KEY"] = api_key
99
+
100
+ openai_async_client = (
101
+ AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
102
+ )
103
+ kwargs.pop("hashing_kv", None)
104
+ kwargs.pop("keyword_extraction", None)
105
+ messages = []
106
+ if system_prompt:
107
+ messages.append({"role": "system", "content": system_prompt})
108
+ messages.extend(history_messages)
109
+ messages.append({"role": "user", "content": prompt})
110
+
111
+ # Log the query inputs for debugging
112
+ logger.debug("===== Query Input to LLM =====")
113
+ logger.debug(f"Query: {prompt}")
114
+ logger.debug(f"System prompt: {system_prompt}")
115
+ logger.debug("Full context:")
116
+ if "response_format" in kwargs:
117
+ response = await openai_async_client.beta.chat.completions.parse(
118
+ model=model, messages=messages, **kwargs
119
+ )
120
+ else:
121
+ response = await openai_async_client.chat.completions.create(
122
+ model=model, messages=messages, **kwargs
123
+ )
124
+
125
+ if hasattr(response, "__aiter__"):
126
+
127
+ async def inner():
128
+ async for chunk in response:
129
+ content = chunk.choices[0].delta.content
130
+ if content is None:
131
+ continue
132
+ if r"\u" in content:
133
+ content = safe_unicode_decode(content.encode("utf-8"))
134
+ yield content
135
+
136
+ return inner()
137
+ else:
138
+ content = response.choices[0].message.content
139
+ if r"\u" in content:
140
+ content = safe_unicode_decode(content.encode("utf-8"))
141
+ return content
142
+
143
+
144
+
145
+ async def openai_complete(
146
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
147
+ ) -> Union[str, AsyncIterator[str]]:
148
+ keyword_extraction = kwargs.pop("keyword_extraction", None)
149
+ if keyword_extraction:
150
+ kwargs["response_format"] = "json"
151
+ model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
152
+ return await openai_complete_if_cache(
153
+ model_name,
154
+ prompt,
155
+ system_prompt=system_prompt,
156
+ history_messages=history_messages,
157
+ **kwargs,
158
+ )
159
+
160
+
161
+ async def gpt_4o_complete(
162
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
163
+ ) -> str:
164
+ keyword_extraction = kwargs.pop("keyword_extraction", None)
165
+ if keyword_extraction:
166
+ kwargs["response_format"] = GPTKeywordExtractionFormat
167
+ return await openai_complete_if_cache(
168
+ "gpt-4o",
169
+ prompt,
170
+ system_prompt=system_prompt,
171
+ history_messages=history_messages,
172
+ **kwargs,
173
+ )
174
+
175
+
176
+ async def gpt_4o_mini_complete(
177
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
178
+ ) -> str:
179
+ keyword_extraction = kwargs.pop("keyword_extraction", None)
180
+ if keyword_extraction:
181
+ kwargs["response_format"] = GPTKeywordExtractionFormat
182
+ return await openai_complete_if_cache(
183
+ "gpt-4o-mini",
184
+ prompt,
185
+ system_prompt=system_prompt,
186
+ history_messages=history_messages,
187
+ **kwargs,
188
+ )
189
+
190
+
191
+ async def nvidia_openai_complete(
192
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
193
+ ) -> str:
194
+ keyword_extraction = kwargs.pop("keyword_extraction", None)
195
+ result = await openai_complete_if_cache(
196
+ "nvidia/llama-3.1-nemotron-70b-instruct", # context length 128k
197
+ prompt,
198
+ system_prompt=system_prompt,
199
+ history_messages=history_messages,
200
+ base_url="https://integrate.api.nvidia.com/v1",
201
+ **kwargs,
202
+ )
203
+ if keyword_extraction: # TODO: use JSON API
204
+ return locate_json_string_body_from_string(result)
205
+ return result
206
+
207
+
208
+
209
+ @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
210
+ @retry(
211
+ stop=stop_after_attempt(3),
212
+ wait=wait_exponential(multiplier=1, min=4, max=60),
213
+ retry=retry_if_exception_type(
214
+ (RateLimitError, APIConnectionError, APITimeoutError)
215
+ ),
216
+ )
217
+ async def openai_embed(
218
+ texts: list[str],
219
+ model: str = "text-embedding-3-small",
220
+ base_url: str = None,
221
+ api_key: str = None,
222
+ ) -> np.ndarray:
223
+ if api_key:
224
+ os.environ["OPENAI_API_KEY"] = api_key
225
+
226
+ openai_async_client = (
227
+ AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
228
+ )
229
+ response = await openai_async_client.embeddings.create(
230
+ model=model, input=texts, encoding_format="float"
231
+ )
232
+ return np.array([dp.embedding for dp in response.data])
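The OpenAI helpers keep their previous behavior; only the import path moves from `lightrag.llm` to `lightrag.llm.openai`. A small direct-call sketch, assuming `OPENAI_API_KEY` is set in the environment:

```python
import asyncio
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed

async def main():
    answer = await gpt_4o_mini_complete("Reply with one word: ready?")
    vectors = await openai_embed(["hello world"])  # text-embedding-3-small, 1536-dim
    print(answer, vectors.shape)

asyncio.run(main())
```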
lightrag/llm/siliconcloud.py ADDED
@@ -0,0 +1,121 @@
1
+ """
2
+ SiliconCloud Embedding Interface Module
3
+ ==========================
4
+
5
+ This module provides interfaces for interacting with SiliconCloud system,
6
+ including embedding capabilities.
7
+
8
+ Author: Lightrag team
9
+ Created: 2024-01-24
10
+ License: MIT License
11
+
12
+ Copyright (c) 2024 Lightrag
13
+
14
+ Permission is hereby granted, free of charge, to any person obtaining a copy
15
+ of this software and associated documentation files (the "Software"), to deal
16
+ in the Software without restriction, including without limitation the rights
17
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
+ copies of the Software, and to permit persons to whom the Software is
19
+ furnished to do so, subject to the following conditions:
20
+
21
+ Version: 1.0.0
22
+
23
+ Change Log:
24
+ - 1.0.0 (2024-01-24): Initial release
25
+ * Added embedding generation
26
+
27
+ Dependencies:
28
+ - tenacity
29
+ - numpy
30
+ - pipmaster
31
+ - Python >= 3.10
32
+
33
+ Usage:
34
+ from lightrag.llm.siliconcloud import siliconcloud_embedding
35
+ """
36
+
37
+ __version__ = "1.0.0"
38
+ __author__ = "lightrag Team"
39
+ __status__ = "Production"
40
+
41
+ import sys
42
+ import copy
43
+ import os
44
+ import json
+ import base64
+ import struct
45
+
46
+ if sys.version_info < (3, 9):
47
+ from typing import AsyncIterator
48
+ else:
49
+ from collections.abc import AsyncIterator
50
+ import pipmaster as pm # Pipmaster for dynamic library install
51
+
52
+ # install specific modules
53
+ if not pm.is_installed("aiohttp"):
54
+ pm.install("aiohttp")
55
+
56
+ from openai import (
57
+ AsyncOpenAI,
58
+ AsyncAzureOpenAI,
59
+ APIConnectionError,
60
+ RateLimitError,
61
+ APITimeoutError,
62
+ )
63
+ from tenacity import (
64
+ retry,
65
+ stop_after_attempt,
66
+ wait_exponential,
67
+ retry_if_exception_type,
68
+ )
69
+
70
+ from lightrag.utils import (
71
+ wrap_embedding_func_with_attrs,
72
+ locate_json_string_body_from_string,
73
+ safe_unicode_decode,
74
+ logger,
75
+ )
76
+
77
+ from lightrag.types import GPTKeywordExtractionFormat
78
+ from functools import lru_cache
79
+
80
+ import numpy as np
81
+ from typing import Union
82
+ import aiohttp
83
+
84
+ @retry(
85
+ stop=stop_after_attempt(3),
86
+ wait=wait_exponential(multiplier=1, min=4, max=60),
87
+ retry=retry_if_exception_type(
88
+ (RateLimitError, APIConnectionError, APITimeoutError)
89
+ ),
90
+ )
91
+ async def siliconcloud_embedding(
92
+ texts: list[str],
93
+ model: str = "netease-youdao/bce-embedding-base_v1",
94
+ base_url: str = "https://api.siliconflow.cn/v1/embeddings",
95
+ max_token_size: int = 512,
96
+ api_key: str = None,
97
+ ) -> np.ndarray:
98
+ if api_key and not api_key.startswith("Bearer "):
99
+ api_key = "Bearer " + api_key
100
+
101
+ headers = {"Authorization": api_key, "Content-Type": "application/json"}
102
+
103
+ truncate_texts = [text[0:max_token_size] for text in texts]
104
+
105
+ payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"}
106
+
107
+ base64_strings = []
108
+ async with aiohttp.ClientSession() as session:
109
+ async with session.post(base_url, headers=headers, json=payload) as response:
110
+ content = await response.json()
111
+ if "code" in content:
112
+ raise ValueError(content)
113
+ base64_strings = [item["embedding"] for item in content["data"]]
114
+
115
+ embeddings = []
116
+ for string in base64_strings:
117
+ decode_bytes = base64.b64decode(string)
118
+ n = len(decode_bytes) // 4
119
+ float_array = struct.unpack("<" + "f" * n, decode_bytes)
120
+ embeddings.append(float_array)
121
+ return np.array(embeddings)
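`siliconcloud_embedding` is the only public helper in this module; the "Bearer " prefix is added automatically when missing. A minimal sketch with a placeholder key:

```python
import asyncio
from lightrag.llm.siliconcloud import siliconcloud_embedding

async def main():
    vectors = await siliconcloud_embedding(
        ["LightRAG now ships one module per LLM provider."],
        api_key="sk-...",  # placeholder key
    )
    print(vectors.shape)

asyncio.run(main())
```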
lightrag/llm/zhipu.py ADDED
@@ -0,0 +1,250 @@
1
+ """
2
+ Zhipu LLM Interface Module
3
+ ==========================
4
+
5
+ This module provides interfaces for interacting with Zhipu AI's language models,
6
+ including text generation and embedding capabilities.
7
+
8
+ Author: Lightrag team
9
+ Created: 2024-01-24
10
+ License: MIT License
11
+
12
+ Copyright (c) 2024 Lightrag
13
+
14
+ Permission is hereby granted, free of charge, to any person obtaining a copy
15
+ of this software and associated documentation files (the "Software"), to deal
16
+ in the Software without restriction, including without limitation the rights
17
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18
+ copies of the Software, and to permit persons to whom the Software is
19
+ furnished to do so, subject to the following conditions:
20
+
21
+ Version: 1.0.0
22
+
23
+ Change Log:
24
+ - 1.0.0 (2024-01-24): Initial release
25
+ * Added async chat completion support
26
+ * Added embedding generation
27
+ * Added stream response capability
28
+
29
+ Dependencies:
30
+ - tenacity
31
+ - numpy
32
+ - pipmaster
33
+ - Python >= 3.10
34
+
35
+ Usage:
36
+ from lightrag.llm.zhipu import zhipu_complete, zhipu_embedding
37
+ """
38
+
39
+ __version__ = "1.0.0"
40
+ __author__ = "lightrag Team"
41
+ __status__ = "Production"
42
+
43
+ import sys
44
+ import re
45
+ import json
46
+
47
+ if sys.version_info < (3, 9):
48
+ from typing import AsyncIterator
49
+ else:
50
+ from collections.abc import AsyncIterator
51
+ import pipmaster as pm # Pipmaster for dynamic library install
52
+
53
+ # install specific modules
54
+ if not pm.is_installed("zhipuai"):
55
+ pm.install("zhipuai")
56
+
57
+ from openai import (
58
+ AsyncOpenAI,
59
+ AsyncAzureOpenAI,
60
+ APIConnectionError,
61
+ RateLimitError,
62
+ APITimeoutError,
63
+ )
64
+ from tenacity import (
65
+ retry,
66
+ stop_after_attempt,
67
+ wait_exponential,
68
+ retry_if_exception_type,
69
+ )
70
+
71
+ from lightrag.utils import (
72
+ wrap_embedding_func_with_attrs,
73
+ locate_json_string_body_from_string,
74
+ safe_unicode_decode,
75
+ logger,
76
+ )
77
+
78
+ from lightrag.types import GPTKeywordExtractionFormat
79
+ from functools import lru_cache
80
+
81
+ import numpy as np
82
+ from typing import Union, List, Optional, Dict
83
+
84
+ @retry(
85
+ stop=stop_after_attempt(3),
86
+ wait=wait_exponential(multiplier=1, min=4, max=10),
87
+ retry=retry_if_exception_type(
88
+ (RateLimitError, APIConnectionError, APITimeoutError)
89
+ ),
90
+ )
91
+ async def zhipu_complete_if_cache(
92
+ prompt: Union[str, List[Dict[str, str]]],
93
+ model: str = "glm-4-flashx", # The most cost/performance balance model in glm-4 series
94
+ api_key: Optional[str] = None,
95
+ system_prompt: Optional[str] = None,
96
+ history_messages: List[Dict[str, str]] = [],
97
+ **kwargs,
98
+ ) -> str:
99
+ # dynamically load ZhipuAI
100
+ try:
101
+ from zhipuai import ZhipuAI
102
+ except ImportError:
103
+ raise ImportError("Please install zhipuai before initializing the zhipuai backend.")
104
+
105
+ if api_key:
106
+ client = ZhipuAI(api_key=api_key)
107
+ else:
108
+ # please set ZHIPUAI_API_KEY in your environment
109
+ # os.environ["ZHIPUAI_API_KEY"]
110
+ client = ZhipuAI()
111
+
112
+ messages = []
113
+
114
+ if not system_prompt:
115
+ system_prompt = "You are a helpful assistant. Note that sensitive words in the content should be replaced with ***"
116
+
117
+ # Add system prompt if provided
118
+ if system_prompt:
119
+ messages.append({"role": "system", "content": system_prompt})
120
+ messages.extend(history_messages)
121
+ messages.append({"role": "user", "content": prompt})
122
+
123
+ # Add debug logging
124
+ logger.debug("===== Query Input to LLM =====")
125
+ logger.debug(f"Query: {prompt}")
126
+ logger.debug(f"System prompt: {system_prompt}")
127
+
128
+ # Remove unsupported kwargs
129
+ kwargs = {
130
+ k: v for k, v in kwargs.items() if k not in ["hashing_kv", "keyword_extraction"]
131
+ }
132
+
133
+ response = client.chat.completions.create(model=model, messages=messages, **kwargs)
134
+
135
+ return response.choices[0].message.content
136
+
137
+
138
+ async def zhipu_complete(
139
+ prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
140
+ ):
141
+ # Pop keyword_extraction from kwargs to avoid passing it to zhipu_complete_if_cache
142
+ keyword_extraction = kwargs.pop("keyword_extraction", None)
143
+
144
+ if keyword_extraction:
145
+ # Add a system prompt to guide the model to return JSON format
146
+ extraction_prompt = """You are a helpful assistant that extracts keywords from text.
147
+ Please analyze the content and extract two types of keywords:
148
+ 1. High-level keywords: Important concepts and main themes
149
+ 2. Low-level keywords: Specific details and supporting elements
150
+
151
+ Return your response in this exact JSON format:
152
+ {
153
+ "high_level_keywords": ["keyword1", "keyword2"],
154
+ "low_level_keywords": ["keyword1", "keyword2", "keyword3"]
155
+ }
156
+
157
+ Only return the JSON, no other text."""
158
+
159
+ # Combine with existing system prompt if any
160
+ if system_prompt:
161
+ system_prompt = f"{system_prompt}\n\n{extraction_prompt}"
162
+ else:
163
+ system_prompt = extraction_prompt
164
+
165
+ try:
166
+ response = await zhipu_complete_if_cache(
167
+ prompt=prompt,
168
+ system_prompt=system_prompt,
169
+ history_messages=history_messages,
170
+ **kwargs,
171
+ )
172
+
173
+ # Try to parse as JSON
174
+ try:
175
+ data = json.loads(response)
176
+ return GPTKeywordExtractionFormat(
177
+ high_level_keywords=data.get("high_level_keywords", []),
178
+ low_level_keywords=data.get("low_level_keywords", []),
179
+ )
180
+ except json.JSONDecodeError:
181
+ # If direct JSON parsing fails, try to extract JSON from text
182
+ match = re.search(r"\{[\s\S]*\}", response)
183
+ if match:
184
+ try:
185
+ data = json.loads(match.group())
186
+ return GPTKeywordExtractionFormat(
187
+ high_level_keywords=data.get("high_level_keywords", []),
188
+ low_level_keywords=data.get("low_level_keywords", []),
189
+ )
190
+ except json.JSONDecodeError:
191
+ pass
192
+
193
+ # If all parsing fails, log warning and return empty format
194
+ logger.warning(
195
+ f"Failed to parse keyword extraction response: {response}"
196
+ )
197
+ return GPTKeywordExtractionFormat(
198
+ high_level_keywords=[], low_level_keywords=[]
199
+ )
200
+ except Exception as e:
201
+ logger.error(f"Error during keyword extraction: {str(e)}")
202
+ return GPTKeywordExtractionFormat(
203
+ high_level_keywords=[], low_level_keywords=[]
204
+ )
205
+ else:
206
+ # For non-keyword-extraction, just return the raw response string
207
+ return await zhipu_complete_if_cache(
208
+ prompt=prompt,
209
+ system_prompt=system_prompt,
210
+ history_messages=history_messages,
211
+ **kwargs,
212
+ )
213
+
214
+
215
+ @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
216
+ @retry(
217
+ stop=stop_after_attempt(3),
218
+ wait=wait_exponential(multiplier=1, min=4, max=60),
219
+ retry=retry_if_exception_type(
220
+ (RateLimitError, APIConnectionError, APITimeoutError)
221
+ ),
222
+ )
223
+ async def zhipu_embedding(
224
+ texts: list[str], model: str = "embedding-3", api_key: str = None, **kwargs
225
+ ) -> np.ndarray:
226
+ # dynamically load ZhipuAI
227
+ try:
228
+ from zhipuai import ZhipuAI
229
+ except ImportError:
230
+ raise ImportError("Please install zhipuai before initializing the zhipuai backend.")
231
+ if api_key:
232
+ client = ZhipuAI(api_key=api_key)
233
+ else:
234
+ # please set ZHIPUAI_API_KEY in your environment
235
+ # os.environ["ZHIPUAI_API_KEY"]
236
+ client = ZhipuAI()
237
+
238
+ # Convert single text to list if needed
239
+ if isinstance(texts, str):
240
+ texts = [texts]
241
+
242
+ embeddings = []
243
+ for text in texts:
244
+ try:
245
+ response = client.embeddings.create(model=model, input=[text], **kwargs)
246
+ embeddings.append(response.data[0].embedding)
247
+ except Exception as e:
248
+ raise Exception(f"Error calling ChatGLM Embedding API: {str(e)}")
249
+
250
+ return np.array(embeddings)
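The Zhipu helpers can also be used standalone; a short sketch, assuming `ZHIPUAI_API_KEY` is exported in the environment:

```python
import asyncio
from lightrag.llm.zhipu import zhipu_complete_if_cache, zhipu_embedding

async def main():
    reply = await zhipu_complete_if_cache("Define a knowledge graph in one sentence.")
    vectors = await zhipu_embedding(["knowledge graph"])  # embedding-3, 1024-dim
    print(reply, vectors.shape)

asyncio.run(main())
```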
lightrag/storage.py CHANGED
@@ -6,6 +6,8 @@ from dataclasses import dataclass
6
  from typing import Any, Union, cast, Dict
7
  import networkx as nx
8
  import numpy as np
 
 
9
  from nano_vectordb import NanoVectorDB
10
  import time
11
 
 
6
  from typing import Any, Union, cast, Dict
7
  import networkx as nx
8
  import numpy as np
9
+ import pipmaster as pm
10
+
11
  from nano_vectordb import NanoVectorDB
12
  import time
13