Daniel.y commited on
Commit
c38d821
·
unverified ·
2 Parent(s): 33b704f 48cf780

Merge pull request #1572 from danielaskdd/optimize-ollama

Browse files
README-zh.md CHANGED
@@ -415,7 +415,7 @@ rag = LightRAG(
415
  embedding_func=EmbeddingFunc(
416
  embedding_dim=768,
417
  max_token_size=8192,
418
- func=lambda texts: ollama_embedding(
419
  texts,
420
  embed_model="nomic-embed-text"
421
  )
 
415
  embedding_func=EmbeddingFunc(
416
  embedding_dim=768,
417
  max_token_size=8192,
418
+ func=lambda texts: ollama_embed(
419
  texts,
420
  embed_model="nomic-embed-text"
421
  )
README.md CHANGED
@@ -447,7 +447,7 @@ rag = LightRAG(
447
  embedding_func=EmbeddingFunc(
448
  embedding_dim=768,
449
  max_token_size=8192,
450
- func=lambda texts: ollama_embedding(
451
  texts,
452
  embed_model="nomic-embed-text"
453
  )
 
447
  embedding_func=EmbeddingFunc(
448
  embedding_dim=768,
449
  max_token_size=8192,
450
+ func=lambda texts: ollama_embed(
451
  texts,
452
  embed_model="nomic-embed-text"
453
  )
examples/lightrag_ollama_age_demo.py DELETED
@@ -1,113 +0,0 @@
1
- import asyncio
2
- import nest_asyncio
3
-
4
- import inspect
5
- import logging
6
- import os
7
-
8
- from lightrag import LightRAG, QueryParam
9
- from lightrag.llm.ollama import ollama_embed, ollama_model_complete
10
- from lightrag.utils import EmbeddingFunc
11
- from lightrag.kg.shared_storage import initialize_pipeline_status
12
-
13
- nest_asyncio.apply()
14
-
15
- WORKING_DIR = "./dickens_age"
16
-
17
- logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
18
-
19
- if not os.path.exists(WORKING_DIR):
20
- os.mkdir(WORKING_DIR)
21
-
22
- # AGE
23
- os.environ["AGE_POSTGRES_DB"] = "postgresDB"
24
- os.environ["AGE_POSTGRES_USER"] = "postgresUser"
25
- os.environ["AGE_POSTGRES_PASSWORD"] = "postgresPW"
26
- os.environ["AGE_POSTGRES_HOST"] = "localhost"
27
- os.environ["AGE_POSTGRES_PORT"] = "5455"
28
- os.environ["AGE_GRAPH_NAME"] = "dickens"
29
-
30
-
31
- async def initialize_rag():
32
- rag = LightRAG(
33
- working_dir=WORKING_DIR,
34
- llm_model_func=ollama_model_complete,
35
- llm_model_name="llama3.1:8b",
36
- llm_model_max_async=4,
37
- llm_model_max_token_size=32768,
38
- llm_model_kwargs={
39
- "host": "http://localhost:11434",
40
- "options": {"num_ctx": 32768},
41
- },
42
- embedding_func=EmbeddingFunc(
43
- embedding_dim=768,
44
- max_token_size=8192,
45
- func=lambda texts: ollama_embed(
46
- texts, embed_model="nomic-embed-text", host="http://localhost:11434"
47
- ),
48
- ),
49
- graph_storage="AGEStorage",
50
- )
51
-
52
- await rag.initialize_storages()
53
- await initialize_pipeline_status()
54
-
55
- return rag
56
-
57
-
58
- async def print_stream(stream):
59
- async for chunk in stream:
60
- print(chunk, end="", flush=True)
61
-
62
-
63
- def main():
64
- # Initialize RAG instance
65
- rag = asyncio.run(initialize_rag())
66
-
67
- # Insert example text
68
- with open("./book.txt", "r", encoding="utf-8") as f:
69
- rag.insert(f.read())
70
-
71
- # Test different query modes
72
- print("\nNaive Search:")
73
- print(
74
- rag.query(
75
- "What are the top themes in this story?", param=QueryParam(mode="naive")
76
- )
77
- )
78
-
79
- print("\nLocal Search:")
80
- print(
81
- rag.query(
82
- "What are the top themes in this story?", param=QueryParam(mode="local")
83
- )
84
- )
85
-
86
- print("\nGlobal Search:")
87
- print(
88
- rag.query(
89
- "What are the top themes in this story?", param=QueryParam(mode="global")
90
- )
91
- )
92
-
93
- print("\nHybrid Search:")
94
- print(
95
- rag.query(
96
- "What are the top themes in this story?", param=QueryParam(mode="hybrid")
97
- )
98
- )
99
-
100
- # stream response
101
- resp = rag.query(
102
- "What are the top themes in this story?",
103
- param=QueryParam(mode="hybrid", stream=True),
104
- )
105
-
106
- if inspect.isasyncgen(resp):
107
- asyncio.run(print_stream(resp))
108
- else:
109
- print(resp)
110
-
111
-
112
- if __name__ == "__main__":
113
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
examples/lightrag_ollama_demo.py CHANGED
@@ -1,19 +1,82 @@
1
  import asyncio
2
- import nest_asyncio
3
-
4
  import os
5
  import inspect
6
  import logging
 
7
  from lightrag import LightRAG, QueryParam
8
  from lightrag.llm.ollama import ollama_model_complete, ollama_embed
9
- from lightrag.utils import EmbeddingFunc
10
  from lightrag.kg.shared_storage import initialize_pipeline_status
11
 
12
- nest_asyncio.apply()
 
 
13
 
14
  WORKING_DIR = "./dickens"
15
 
16
- logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  if not os.path.exists(WORKING_DIR):
19
  os.mkdir(WORKING_DIR)
@@ -23,18 +86,20 @@ async def initialize_rag():
23
  rag = LightRAG(
24
  working_dir=WORKING_DIR,
25
  llm_model_func=ollama_model_complete,
26
- llm_model_name="gemma2:2b",
27
- llm_model_max_async=4,
28
- llm_model_max_token_size=32768,
29
  llm_model_kwargs={
30
- "host": "http://localhost:11434",
31
- "options": {"num_ctx": 32768},
 
32
  },
33
  embedding_func=EmbeddingFunc(
34
- embedding_dim=768,
35
- max_token_size=8192,
36
  func=lambda texts: ollama_embed(
37
- texts, embed_model="nomic-embed-text", host="http://localhost:11434"
 
 
38
  ),
39
  ),
40
  )
@@ -50,54 +115,103 @@ async def print_stream(stream):
50
  print(chunk, end="", flush=True)
51
 
52
 
53
- def main():
54
- # Initialize RAG instance
55
- rag = asyncio.run(initialize_rag())
56
-
57
- # Insert example text
58
- with open("./book.txt", "r", encoding="utf-8") as f:
59
- rag.insert(f.read())
60
-
61
- # Test different query modes
62
- print("\nNaive Search:")
63
- print(
64
- rag.query(
65
- "What are the top themes in this story?", param=QueryParam(mode="naive")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  )
67
- )
68
-
69
- print("\nLocal Search:")
70
- print(
71
- rag.query(
72
- "What are the top themes in this story?", param=QueryParam(mode="local")
 
 
 
 
 
 
73
  )
74
- )
75
-
76
- print("\nGlobal Search:")
77
- print(
78
- rag.query(
79
- "What are the top themes in this story?", param=QueryParam(mode="global")
 
 
 
 
 
 
80
  )
81
- )
82
-
83
- print("\nHybrid Search:")
84
- print(
85
- rag.query(
86
- "What are the top themes in this story?", param=QueryParam(mode="hybrid")
 
 
 
 
 
 
87
  )
88
- )
89
-
90
- # stream response
91
- resp = rag.query(
92
- "What are the top themes in this story?",
93
- param=QueryParam(mode="hybrid", stream=True),
94
- )
95
 
96
- if inspect.isasyncgen(resp):
97
- asyncio.run(print_stream(resp))
98
- else:
99
- print(resp)
 
 
100
 
101
 
102
  if __name__ == "__main__":
103
- main()
 
 
 
 
1
  import asyncio
 
 
2
  import os
3
  import inspect
4
  import logging
5
+ import logging.config
6
  from lightrag import LightRAG, QueryParam
7
  from lightrag.llm.ollama import ollama_model_complete, ollama_embed
8
+ from lightrag.utils import EmbeddingFunc, logger, set_verbose_debug
9
  from lightrag.kg.shared_storage import initialize_pipeline_status
10
 
11
+ from dotenv import load_dotenv
12
+
13
+ load_dotenv(dotenv_path=".env", override=False)
14
 
15
  WORKING_DIR = "./dickens"
16
 
17
+
18
+ def configure_logging():
19
+ """Configure logging for the application"""
20
+
21
+ # Reset any existing handlers to ensure clean configuration
22
+ for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]:
23
+ logger_instance = logging.getLogger(logger_name)
24
+ logger_instance.handlers = []
25
+ logger_instance.filters = []
26
+
27
+ # Get log directory path from environment variable or use current directory
28
+ log_dir = os.getenv("LOG_DIR", os.getcwd())
29
+ log_file_path = os.path.abspath(os.path.join(log_dir, "lightrag_ollama_demo.log"))
30
+
31
+ print(f"\nLightRAG compatible demo log file: {log_file_path}\n")
32
+ os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
33
+
34
+ # Get log file max size and backup count from environment variables
35
+ log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB
36
+ log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups
37
+
38
+ logging.config.dictConfig(
39
+ {
40
+ "version": 1,
41
+ "disable_existing_loggers": False,
42
+ "formatters": {
43
+ "default": {
44
+ "format": "%(levelname)s: %(message)s",
45
+ },
46
+ "detailed": {
47
+ "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
48
+ },
49
+ },
50
+ "handlers": {
51
+ "console": {
52
+ "formatter": "default",
53
+ "class": "logging.StreamHandler",
54
+ "stream": "ext://sys.stderr",
55
+ },
56
+ "file": {
57
+ "formatter": "detailed",
58
+ "class": "logging.handlers.RotatingFileHandler",
59
+ "filename": log_file_path,
60
+ "maxBytes": log_max_bytes,
61
+ "backupCount": log_backup_count,
62
+ "encoding": "utf-8",
63
+ },
64
+ },
65
+ "loggers": {
66
+ "lightrag": {
67
+ "handlers": ["console", "file"],
68
+ "level": "INFO",
69
+ "propagate": False,
70
+ },
71
+ },
72
+ }
73
+ )
74
+
75
+ # Set the logger level to INFO
76
+ logger.setLevel(logging.INFO)
77
+ # Enable verbose debug if needed
78
+ set_verbose_debug(os.getenv("VERBOSE_DEBUG", "false").lower() == "true")
79
+
80
 
81
  if not os.path.exists(WORKING_DIR):
82
  os.mkdir(WORKING_DIR)
 
86
  rag = LightRAG(
87
  working_dir=WORKING_DIR,
88
  llm_model_func=ollama_model_complete,
89
+ llm_model_name=os.getenv("LLM_MODEL", "qwen2.5-coder:7b"),
90
+ llm_model_max_token_size=8192,
 
91
  llm_model_kwargs={
92
+ "host": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
93
+ "options": {"num_ctx": 8192},
94
+ "timeout": int(os.getenv("TIMEOUT", "300")),
95
  },
96
  embedding_func=EmbeddingFunc(
97
+ embedding_dim=int(os.getenv("EMBEDDING_DIM", "1024")),
98
+ max_token_size=int(os.getenv("MAX_EMBED_TOKENS", "8192")),
99
  func=lambda texts: ollama_embed(
100
+ texts,
101
+ embed_model=os.getenv("EMBEDDING_MODEL", "bge-m3:latest"),
102
+ host=os.getenv("EMBEDDING_BINDING_HOST", "http://localhost:11434"),
103
  ),
104
  ),
105
  )
 
115
  print(chunk, end="", flush=True)
116
 
117
 
118
+ async def main():
119
+ try:
120
+ # Clear old data files
121
+ files_to_delete = [
122
+ "graph_chunk_entity_relation.graphml",
123
+ "kv_store_doc_status.json",
124
+ "kv_store_full_docs.json",
125
+ "kv_store_text_chunks.json",
126
+ "vdb_chunks.json",
127
+ "vdb_entities.json",
128
+ "vdb_relationships.json",
129
+ ]
130
+
131
+ for file in files_to_delete:
132
+ file_path = os.path.join(WORKING_DIR, file)
133
+ if os.path.exists(file_path):
134
+ os.remove(file_path)
135
+ print(f"Deleting old file:: {file_path}")
136
+
137
+ # Initialize RAG instance
138
+ rag = await initialize_rag()
139
+
140
+ # Test embedding function
141
+ test_text = ["This is a test string for embedding."]
142
+ embedding = await rag.embedding_func(test_text)
143
+ embedding_dim = embedding.shape[1]
144
+ print("\n=======================")
145
+ print("Test embedding function")
146
+ print("========================")
147
+ print(f"Test dict: {test_text}")
148
+ print(f"Detected embedding dimension: {embedding_dim}\n\n")
149
+
150
+ with open("./book.txt", "r", encoding="utf-8") as f:
151
+ await rag.ainsert(f.read())
152
+
153
+ # Perform naive search
154
+ print("\n=====================")
155
+ print("Query mode: naive")
156
+ print("=====================")
157
+ resp = await rag.aquery(
158
+ "What are the top themes in this story?",
159
+ param=QueryParam(mode="naive", stream=True),
160
  )
161
+ if inspect.isasyncgen(resp):
162
+ await print_stream(resp)
163
+ else:
164
+ print(resp)
165
+
166
+ # Perform local search
167
+ print("\n=====================")
168
+ print("Query mode: local")
169
+ print("=====================")
170
+ resp = await rag.aquery(
171
+ "What are the top themes in this story?",
172
+ param=QueryParam(mode="local", stream=True),
173
  )
174
+ if inspect.isasyncgen(resp):
175
+ await print_stream(resp)
176
+ else:
177
+ print(resp)
178
+
179
+ # Perform global search
180
+ print("\n=====================")
181
+ print("Query mode: global")
182
+ print("=====================")
183
+ resp = await rag.aquery(
184
+ "What are the top themes in this story?",
185
+ param=QueryParam(mode="global", stream=True),
186
  )
187
+ if inspect.isasyncgen(resp):
188
+ await print_stream(resp)
189
+ else:
190
+ print(resp)
191
+
192
+ # Perform hybrid search
193
+ print("\n=====================")
194
+ print("Query mode: hybrid")
195
+ print("=====================")
196
+ resp = await rag.aquery(
197
+ "What are the top themes in this story?",
198
+ param=QueryParam(mode="hybrid", stream=True),
199
  )
200
+ if inspect.isasyncgen(resp):
201
+ await print_stream(resp)
202
+ else:
203
+ print(resp)
 
 
 
204
 
205
+ except Exception as e:
206
+ print(f"An error occurred: {e}")
207
+ finally:
208
+ if rag:
209
+ await rag.llm_response_cache.index_done_callback()
210
+ await rag.finalize_storages()
211
 
212
 
213
  if __name__ == "__main__":
214
+ # Configure logging before running the main function
215
+ configure_logging()
216
+ asyncio.run(main())
217
+ print("\nDone!")
examples/lightrag_ollama_gremlin_demo.py DELETED
@@ -1,122 +0,0 @@
1
- ##############################################
2
- # Gremlin storage implementation is deprecated
3
- ##############################################
4
-
5
- import asyncio
6
- import inspect
7
- import os
8
-
9
- # Uncomment these lines below to filter out somewhat verbose INFO level
10
- # logging prints (the default loglevel is INFO).
11
- # This has to go before the lightrag imports to work,
12
- # which triggers linting errors, so we keep it commented out:
13
- # import logging
14
- # logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.WARN)
15
-
16
- from lightrag import LightRAG, QueryParam
17
- from lightrag.llm.ollama import ollama_embed, ollama_model_complete
18
- from lightrag.utils import EmbeddingFunc
19
- from lightrag.kg.shared_storage import initialize_pipeline_status
20
-
21
- WORKING_DIR = "./dickens_gremlin"
22
-
23
- if not os.path.exists(WORKING_DIR):
24
- os.mkdir(WORKING_DIR)
25
-
26
- # Gremlin
27
- os.environ["GREMLIN_HOST"] = "localhost"
28
- os.environ["GREMLIN_PORT"] = "8182"
29
- os.environ["GREMLIN_GRAPH"] = "dickens"
30
-
31
- # Creating a non-default source requires manual
32
- # configuration and a restart on the server: use the dafault "g"
33
- os.environ["GREMLIN_TRAVERSE_SOURCE"] = "g"
34
-
35
- # No authorization by default on docker tinkerpop/gremlin-server
36
- os.environ["GREMLIN_USER"] = ""
37
- os.environ["GREMLIN_PASSWORD"] = ""
38
-
39
-
40
- async def initialize_rag():
41
- rag = LightRAG(
42
- working_dir=WORKING_DIR,
43
- llm_model_func=ollama_model_complete,
44
- llm_model_name="llama3.1:8b",
45
- llm_model_max_async=4,
46
- llm_model_max_token_size=32768,
47
- llm_model_kwargs={
48
- "host": "http://localhost:11434",
49
- "options": {"num_ctx": 32768},
50
- },
51
- embedding_func=EmbeddingFunc(
52
- embedding_dim=768,
53
- max_token_size=8192,
54
- func=lambda texts: ollama_embed(
55
- texts, embed_model="nomic-embed-text", host="http://localhost:11434"
56
- ),
57
- ),
58
- graph_storage="GremlinStorage",
59
- )
60
-
61
- await rag.initialize_storages()
62
- await initialize_pipeline_status()
63
-
64
- return rag
65
-
66
-
67
- async def print_stream(stream):
68
- async for chunk in stream:
69
- print(chunk, end="", flush=True)
70
-
71
-
72
- def main():
73
- # Initialize RAG instance
74
- rag = asyncio.run(initialize_rag())
75
-
76
- # Insert example text
77
- with open("./book.txt", "r", encoding="utf-8") as f:
78
- rag.insert(f.read())
79
-
80
- # Test different query modes
81
- print("\nNaive Search:")
82
- print(
83
- rag.query(
84
- "What are the top themes in this story?", param=QueryParam(mode="naive")
85
- )
86
- )
87
-
88
- print("\nLocal Search:")
89
- print(
90
- rag.query(
91
- "What are the top themes in this story?", param=QueryParam(mode="local")
92
- )
93
- )
94
-
95
- print("\nGlobal Search:")
96
- print(
97
- rag.query(
98
- "What are the top themes in this story?", param=QueryParam(mode="global")
99
- )
100
- )
101
-
102
- print("\nHybrid Search:")
103
- print(
104
- rag.query(
105
- "What are the top themes in this story?", param=QueryParam(mode="hybrid")
106
- )
107
- )
108
-
109
- # stream response
110
- resp = rag.query(
111
- "What are the top themes in this story?",
112
- param=QueryParam(mode="hybrid", stream=True),
113
- )
114
-
115
- if inspect.isasyncgen(resp):
116
- asyncio.run(print_stream(resp))
117
- else:
118
- print(resp)
119
-
120
-
121
- if __name__ == "__main__":
122
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
examples/lightrag_ollama_neo4j_milvus_mongo_demo.py DELETED
@@ -1,104 +0,0 @@
1
- import os
2
- from lightrag import LightRAG, QueryParam
3
- from lightrag.llm.ollama import ollama_model_complete, ollama_embed
4
- from lightrag.utils import EmbeddingFunc
5
- import asyncio
6
- import nest_asyncio
7
-
8
- nest_asyncio.apply()
9
- from lightrag.kg.shared_storage import initialize_pipeline_status
10
-
11
- # WorkingDir
12
- ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
13
- WORKING_DIR = os.path.join(ROOT_DIR, "myKG")
14
- if not os.path.exists(WORKING_DIR):
15
- os.mkdir(WORKING_DIR)
16
- print(f"WorkingDir: {WORKING_DIR}")
17
-
18
- # mongo
19
- os.environ["MONGO_URI"] = "mongodb://root:root@localhost:27017/"
20
- os.environ["MONGO_DATABASE"] = "LightRAG"
21
-
22
- # neo4j
23
- BATCH_SIZE_NODES = 500
24
- BATCH_SIZE_EDGES = 100
25
- os.environ["NEO4J_URI"] = "bolt://localhost:7687"
26
- os.environ["NEO4J_USERNAME"] = "neo4j"
27
- os.environ["NEO4J_PASSWORD"] = "neo4j"
28
-
29
- # milvus
30
- os.environ["MILVUS_URI"] = "http://localhost:19530"
31
- os.environ["MILVUS_USER"] = "root"
32
- os.environ["MILVUS_PASSWORD"] = "root"
33
- os.environ["MILVUS_DB_NAME"] = "lightrag"
34
-
35
-
36
- async def initialize_rag():
37
- rag = LightRAG(
38
- working_dir=WORKING_DIR,
39
- llm_model_func=ollama_model_complete,
40
- llm_model_name="qwen2.5:14b",
41
- llm_model_max_async=4,
42
- llm_model_max_token_size=32768,
43
- llm_model_kwargs={
44
- "host": "http://127.0.0.1:11434",
45
- "options": {"num_ctx": 32768},
46
- },
47
- embedding_func=EmbeddingFunc(
48
- embedding_dim=1024,
49
- max_token_size=8192,
50
- func=lambda texts: ollama_embed(
51
- texts=texts, embed_model="bge-m3:latest", host="http://127.0.0.1:11434"
52
- ),
53
- ),
54
- kv_storage="MongoKVStorage",
55
- graph_storage="Neo4JStorage",
56
- vector_storage="MilvusVectorDBStorage",
57
- )
58
-
59
- await rag.initialize_storages()
60
- await initialize_pipeline_status()
61
-
62
- return rag
63
-
64
-
65
- def main():
66
- # Initialize RAG instance
67
- rag = asyncio.run(initialize_rag())
68
-
69
- # Insert example text
70
- with open("./book.txt", "r", encoding="utf-8") as f:
71
- rag.insert(f.read())
72
-
73
- # Test different query modes
74
- print("\nNaive Search:")
75
- print(
76
- rag.query(
77
- "What are the top themes in this story?", param=QueryParam(mode="naive")
78
- )
79
- )
80
-
81
- print("\nLocal Search:")
82
- print(
83
- rag.query(
84
- "What are the top themes in this story?", param=QueryParam(mode="local")
85
- )
86
- )
87
-
88
- print("\nGlobal Search:")
89
- print(
90
- rag.query(
91
- "What are the top themes in this story?", param=QueryParam(mode="global")
92
- )
93
- )
94
-
95
- print("\nHybrid Search:")
96
- print(
97
- rag.query(
98
- "What are the top themes in this story?", param=QueryParam(mode="hybrid")
99
- )
100
- )
101
-
102
-
103
- if __name__ == "__main__":
104
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
examples/lightrag_openai_compatible_demo_embedding_cache.py DELETED
@@ -1,123 +0,0 @@
1
- import os
2
- import asyncio
3
- from lightrag import LightRAG, QueryParam
4
- from lightrag.llm.openai import openai_complete_if_cache, openai_embed
5
- from lightrag.utils import EmbeddingFunc
6
- import numpy as np
7
- from lightrag.kg.shared_storage import initialize_pipeline_status
8
-
9
- WORKING_DIR = "./dickens"
10
-
11
- if not os.path.exists(WORKING_DIR):
12
- os.mkdir(WORKING_DIR)
13
-
14
-
15
- async def llm_model_func(
16
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
17
- ) -> str:
18
- return await openai_complete_if_cache(
19
- "solar-mini",
20
- prompt,
21
- system_prompt=system_prompt,
22
- history_messages=history_messages,
23
- api_key=os.getenv("UPSTAGE_API_KEY"),
24
- base_url="https://api.upstage.ai/v1/solar",
25
- **kwargs,
26
- )
27
-
28
-
29
- async def embedding_func(texts: list[str]) -> np.ndarray:
30
- return await openai_embed(
31
- texts,
32
- model="solar-embedding-1-large-query",
33
- api_key=os.getenv("UPSTAGE_API_KEY"),
34
- base_url="https://api.upstage.ai/v1/solar",
35
- )
36
-
37
-
38
- async def get_embedding_dim():
39
- test_text = ["This is a test sentence."]
40
- embedding = await embedding_func(test_text)
41
- embedding_dim = embedding.shape[1]
42
- return embedding_dim
43
-
44
-
45
- # function test
46
- async def test_funcs():
47
- result = await llm_model_func("How are you?")
48
- print("llm_model_func: ", result)
49
-
50
- result = await embedding_func(["How are you?"])
51
- print("embedding_func: ", result)
52
-
53
-
54
- # asyncio.run(test_funcs())
55
-
56
-
57
- async def initialize_rag():
58
- embedding_dimension = await get_embedding_dim()
59
- print(f"Detected embedding dimension: {embedding_dimension}")
60
-
61
- rag = LightRAG(
62
- working_dir=WORKING_DIR,
63
- embedding_cache_config={
64
- "enabled": True,
65
- "similarity_threshold": 0.90,
66
- },
67
- llm_model_func=llm_model_func,
68
- embedding_func=EmbeddingFunc(
69
- embedding_dim=embedding_dimension,
70
- max_token_size=8192,
71
- func=embedding_func,
72
- ),
73
- )
74
-
75
- await rag.initialize_storages()
76
- await initialize_pipeline_status()
77
-
78
- return rag
79
-
80
-
81
- async def main():
82
- try:
83
- # Initialize RAG instance
84
- rag = await initialize_rag()
85
-
86
- with open("./book.txt", "r", encoding="utf-8") as f:
87
- await rag.ainsert(f.read())
88
-
89
- # Perform naive search
90
- print(
91
- await rag.aquery(
92
- "What are the top themes in this story?", param=QueryParam(mode="naive")
93
- )
94
- )
95
-
96
- # Perform local search
97
- print(
98
- await rag.aquery(
99
- "What are the top themes in this story?", param=QueryParam(mode="local")
100
- )
101
- )
102
-
103
- # Perform global search
104
- print(
105
- await rag.aquery(
106
- "What are the top themes in this story?",
107
- param=QueryParam(mode="global"),
108
- )
109
- )
110
-
111
- # Perform hybrid search
112
- print(
113
- await rag.aquery(
114
- "What are the top themes in this story?",
115
- param=QueryParam(mode="hybrid"),
116
- )
117
- )
118
- except Exception as e:
119
- print(f"An error occurred: {e}")
120
-
121
-
122
- if __name__ == "__main__":
123
- asyncio.run(main())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
examples/lightrag_siliconcloud_demo.py DELETED
@@ -1,103 +0,0 @@
1
- import os
2
- import asyncio
3
- from lightrag import LightRAG, QueryParam
4
- from lightrag.llm.openai import openai_complete_if_cache
5
- from lightrag.llm.siliconcloud import siliconcloud_embedding
6
- from lightrag.utils import EmbeddingFunc
7
- import numpy as np
8
- from lightrag.kg.shared_storage import initialize_pipeline_status
9
-
10
- WORKING_DIR = "./dickens"
11
-
12
- if not os.path.exists(WORKING_DIR):
13
- os.mkdir(WORKING_DIR)
14
-
15
-
16
- async def llm_model_func(
17
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
18
- ) -> str:
19
- return await openai_complete_if_cache(
20
- "Qwen/Qwen2.5-7B-Instruct",
21
- prompt,
22
- system_prompt=system_prompt,
23
- history_messages=history_messages,
24
- api_key=os.getenv("SILICONFLOW_API_KEY"),
25
- base_url="https://api.siliconflow.cn/v1/",
26
- **kwargs,
27
- )
28
-
29
-
30
- async def embedding_func(texts: list[str]) -> np.ndarray:
31
- return await siliconcloud_embedding(
32
- texts,
33
- model="netease-youdao/bce-embedding-base_v1",
34
- api_key=os.getenv("SILICONFLOW_API_KEY"),
35
- max_token_size=512,
36
- )
37
-
38
-
39
- # function test
40
- async def test_funcs():
41
- result = await llm_model_func("How are you?")
42
- print("llm_model_func: ", result)
43
-
44
- result = await embedding_func(["How are you?"])
45
- print("embedding_func: ", result)
46
-
47
-
48
- asyncio.run(test_funcs())
49
-
50
-
51
- async def initialize_rag():
52
- rag = LightRAG(
53
- working_dir=WORKING_DIR,
54
- llm_model_func=llm_model_func,
55
- embedding_func=EmbeddingFunc(
56
- embedding_dim=768, max_token_size=512, func=embedding_func
57
- ),
58
- )
59
-
60
- await rag.initialize_storages()
61
- await initialize_pipeline_status()
62
-
63
- return rag
64
-
65
-
66
- def main():
67
- # Initialize RAG instance
68
- rag = asyncio.run(initialize_rag())
69
-
70
- with open("./book.txt", "r", encoding="utf-8") as f:
71
- rag.insert(f.read())
72
-
73
- # Perform naive search
74
- print(
75
- rag.query(
76
- "What are the top themes in this story?", param=QueryParam(mode="naive")
77
- )
78
- )
79
-
80
- # Perform local search
81
- print(
82
- rag.query(
83
- "What are the top themes in this story?", param=QueryParam(mode="local")
84
- )
85
- )
86
-
87
- # Perform global search
88
- print(
89
- rag.query(
90
- "What are the top themes in this story?", param=QueryParam(mode="global")
91
- )
92
- )
93
-
94
- # Perform hybrid search
95
- print(
96
- rag.query(
97
- "What are the top themes in this story?", param=QueryParam(mode="hybrid")
98
- )
99
- )
100
-
101
-
102
- if __name__ == "__main__":
103
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
examples/lightrag_siliconcloud_track_token_demo.py DELETED
@@ -1,110 +0,0 @@
1
- import os
2
- import asyncio
3
- from lightrag import LightRAG, QueryParam
4
- from lightrag.llm.openai import openai_complete_if_cache
5
- from lightrag.llm.siliconcloud import siliconcloud_embedding
6
- from lightrag.utils import EmbeddingFunc
7
- from lightrag.utils import TokenTracker
8
- import numpy as np
9
- from lightrag.kg.shared_storage import initialize_pipeline_status
10
- from dotenv import load_dotenv
11
-
12
- load_dotenv()
13
-
14
- token_tracker = TokenTracker()
15
- WORKING_DIR = "./dickens"
16
-
17
- if not os.path.exists(WORKING_DIR):
18
- os.mkdir(WORKING_DIR)
19
-
20
-
21
- async def llm_model_func(
22
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
23
- ) -> str:
24
- return await openai_complete_if_cache(
25
- "Qwen/Qwen2.5-7B-Instruct",
26
- prompt,
27
- system_prompt=system_prompt,
28
- history_messages=history_messages,
29
- api_key=os.getenv("SILICONFLOW_API_KEY"),
30
- base_url="https://api.siliconflow.cn/v1/",
31
- token_tracker=token_tracker,
32
- **kwargs,
33
- )
34
-
35
-
36
- async def embedding_func(texts: list[str]) -> np.ndarray:
37
- return await siliconcloud_embedding(
38
- texts,
39
- model="BAAI/bge-m3",
40
- api_key=os.getenv("SILICONFLOW_API_KEY"),
41
- max_token_size=512,
42
- )
43
-
44
-
45
- # function test
46
- async def test_funcs():
47
- # Context Manager Method
48
- with token_tracker:
49
- result = await llm_model_func("How are you?")
50
- print("llm_model_func: ", result)
51
-
52
-
53
- asyncio.run(test_funcs())
54
-
55
-
56
- async def initialize_rag():
57
- rag = LightRAG(
58
- working_dir=WORKING_DIR,
59
- llm_model_func=llm_model_func,
60
- embedding_func=EmbeddingFunc(
61
- embedding_dim=1024, max_token_size=512, func=embedding_func
62
- ),
63
- )
64
-
65
- await rag.initialize_storages()
66
- await initialize_pipeline_status()
67
-
68
- return rag
69
-
70
-
71
- def main():
72
- # Initialize RAG instance
73
- rag = asyncio.run(initialize_rag())
74
-
75
- # Reset tracker before processing queries
76
- token_tracker.reset()
77
-
78
- with open("./book.txt", "r", encoding="utf-8") as f:
79
- rag.insert(f.read())
80
-
81
- print(
82
- rag.query(
83
- "What are the top themes in this story?", param=QueryParam(mode="naive")
84
- )
85
- )
86
-
87
- print(
88
- rag.query(
89
- "What are the top themes in this story?", param=QueryParam(mode="local")
90
- )
91
- )
92
-
93
- print(
94
- rag.query(
95
- "What are the top themes in this story?", param=QueryParam(mode="global")
96
- )
97
- )
98
-
99
- print(
100
- rag.query(
101
- "What are the top themes in this story?", param=QueryParam(mode="hybrid")
102
- )
103
- )
104
-
105
- # Display final token usage after main query
106
- print("Token usage:", token_tracker.get_usage())
107
-
108
-
109
- if __name__ == "__main__":
110
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
examples/lightrag_tidb_demo.py DELETED
@@ -1,116 +0,0 @@
1
- ###########################################
2
- # TiDB storage implementation is deprecated
3
- ###########################################
4
-
5
- import asyncio
6
- import os
7
-
8
- import numpy as np
9
-
10
- from lightrag import LightRAG, QueryParam
11
- from lightrag.llm import siliconcloud_embedding, openai_complete_if_cache
12
- from lightrag.utils import EmbeddingFunc
13
- from lightrag.kg.shared_storage import initialize_pipeline_status
14
-
15
- WORKING_DIR = "./dickens"
16
-
17
- # We use SiliconCloud API to call LLM on Oracle Cloud
18
- # More docs here https://docs.siliconflow.cn/introduction
19
- BASE_URL = "https://api.siliconflow.cn/v1/"
20
- APIKEY = ""
21
- CHATMODEL = ""
22
- EMBEDMODEL = ""
23
-
24
- os.environ["TIDB_HOST"] = ""
25
- os.environ["TIDB_PORT"] = ""
26
- os.environ["TIDB_USER"] = ""
27
- os.environ["TIDB_PASSWORD"] = ""
28
- os.environ["TIDB_DATABASE"] = "lightrag"
29
-
30
- if not os.path.exists(WORKING_DIR):
31
- os.mkdir(WORKING_DIR)
32
-
33
-
34
- async def llm_model_func(
35
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
36
- ) -> str:
37
- return await openai_complete_if_cache(
38
- CHATMODEL,
39
- prompt,
40
- system_prompt=system_prompt,
41
- history_messages=history_messages,
42
- api_key=APIKEY,
43
- base_url=BASE_URL,
44
- **kwargs,
45
- )
46
-
47
-
48
- async def embedding_func(texts: list[str]) -> np.ndarray:
49
- return await siliconcloud_embedding(
50
- texts,
51
- # model=EMBEDMODEL,
52
- api_key=APIKEY,
53
- )
54
-
55
-
56
- async def get_embedding_dim():
57
- test_text = ["This is a test sentence."]
58
- embedding = await embedding_func(test_text)
59
- embedding_dim = embedding.shape[1]
60
- return embedding_dim
61
-
62
-
63
- async def initialize_rag():
64
- # Detect embedding dimension
65
- embedding_dimension = await get_embedding_dim()
66
- print(f"Detected embedding dimension: {embedding_dimension}")
67
-
68
- # Initialize LightRAG
69
- # We use TiDB DB as the KV/vector
70
- rag = LightRAG(
71
- enable_llm_cache=False,
72
- working_dir=WORKING_DIR,
73
- chunk_token_size=512,
74
- llm_model_func=llm_model_func,
75
- embedding_func=EmbeddingFunc(
76
- embedding_dim=embedding_dimension,
77
- max_token_size=512,
78
- func=embedding_func,
79
- ),
80
- kv_storage="TiDBKVStorage",
81
- vector_storage="TiDBVectorDBStorage",
82
- graph_storage="TiDBGraphStorage",
83
- )
84
-
85
- await rag.initialize_storages()
86
- await initialize_pipeline_status()
87
-
88
- return rag
89
-
90
-
91
- async def main():
92
- try:
93
- # Initialize RAG instance
94
- rag = await initialize_rag()
95
-
96
- with open("./book.txt", "r", encoding="utf-8") as f:
97
- rag.insert(f.read())
98
-
99
- # Perform search in different modes
100
- modes = ["naive", "local", "global", "hybrid"]
101
- for mode in modes:
102
- print("=" * 20, mode, "=" * 20)
103
- print(
104
- await rag.aquery(
105
- "What are the top themes in this story?",
106
- param=QueryParam(mode=mode),
107
- )
108
- )
109
- print("-" * 100, "\n")
110
-
111
- except Exception as e:
112
- print(f"An error occurred: {e}")
113
-
114
-
115
- if __name__ == "__main__":
116
- asyncio.run(main())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
examples/lightrag_tongyi_openai_demo.py DELETED
@@ -1,136 +0,0 @@
1
- import os
2
- import asyncio
3
- from lightrag import LightRAG, QueryParam
4
- from lightrag.utils import EmbeddingFunc
5
- import numpy as np
6
- from dotenv import load_dotenv
7
- import logging
8
- from openai import OpenAI
9
- from lightrag.kg.shared_storage import initialize_pipeline_status
10
-
11
- logging.basicConfig(level=logging.INFO)
12
-
13
- load_dotenv()
14
-
15
- LLM_MODEL = os.environ.get("LLM_MODEL", "qwen-turbo-latest")
16
- LLM_BINDING_HOST = "https://dashscope.aliyuncs.com/compatible-mode/v1"
17
- LLM_BINDING_API_KEY = os.getenv("LLM_BINDING_API_KEY")
18
-
19
- EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-v3")
20
- EMBEDDING_BINDING_HOST = os.getenv("EMBEDDING_BINDING_HOST", LLM_BINDING_HOST)
21
- EMBEDDING_BINDING_API_KEY = os.getenv("EMBEDDING_BINDING_API_KEY", LLM_BINDING_API_KEY)
22
- EMBEDDING_DIM = int(os.environ.get("EMBEDDING_DIM", 1024))
23
- EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
24
- EMBEDDING_MAX_BATCH_SIZE = int(os.environ.get("EMBEDDING_MAX_BATCH_SIZE", 10))
25
-
26
- print(f"LLM_MODEL: {LLM_MODEL}")
27
- print(f"EMBEDDING_MODEL: {EMBEDDING_MODEL}")
28
-
29
- WORKING_DIR = "./dickens"
30
-
31
- if os.path.exists(WORKING_DIR):
32
- import shutil
33
-
34
- shutil.rmtree(WORKING_DIR)
35
-
36
- os.mkdir(WORKING_DIR)
37
-
38
-
39
- async def llm_model_func(
40
- prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
41
- ) -> str:
42
- client = OpenAI(
43
- api_key=LLM_BINDING_API_KEY,
44
- base_url=LLM_BINDING_HOST,
45
- )
46
-
47
- messages = []
48
- if system_prompt:
49
- messages.append({"role": "system", "content": system_prompt})
50
- if history_messages:
51
- messages.extend(history_messages)
52
- messages.append({"role": "user", "content": prompt})
53
-
54
- chat_completion = client.chat.completions.create(
55
- model=LLM_MODEL,
56
- messages=messages,
57
- temperature=kwargs.get("temperature", 0),
58
- top_p=kwargs.get("top_p", 1),
59
- n=kwargs.get("n", 1),
60
- extra_body={"enable_thinking": False},
61
- )
62
- return chat_completion.choices[0].message.content
63
-
64
-
65
- async def embedding_func(texts: list[str]) -> np.ndarray:
66
- client = OpenAI(
67
- api_key=EMBEDDING_BINDING_API_KEY,
68
- base_url=EMBEDDING_BINDING_HOST,
69
- )
70
-
71
- print("##### embedding: texts: %d #####" % len(texts))
72
- max_batch_size = EMBEDDING_MAX_BATCH_SIZE
73
- embeddings = []
74
- for i in range(0, len(texts), max_batch_size):
75
- batch = texts[i : i + max_batch_size]
76
- embedding = client.embeddings.create(model=EMBEDDING_MODEL, input=batch)
77
- embeddings += [item.embedding for item in embedding.data]
78
-
79
- return np.array(embeddings)
80
-
81
-
82
- async def test_funcs():
83
- result = await llm_model_func("How are you?")
84
- print("Resposta do llm_model_func: ", result)
85
-
86
- result = await embedding_func(["How are you?"])
87
- print("Resultado do embedding_func: ", result.shape)
88
- print("Dimensão da embedding: ", result.shape[1])
89
-
90
-
91
- asyncio.run(test_funcs())
92
-
93
-
94
- async def initialize_rag():
95
- rag = LightRAG(
96
- working_dir=WORKING_DIR,
97
- llm_model_func=llm_model_func,
98
- embedding_func=EmbeddingFunc(
99
- embedding_dim=EMBEDDING_DIM,
100
- max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
101
- func=embedding_func,
102
- ),
103
- )
104
-
105
- await rag.initialize_storages()
106
- await initialize_pipeline_status()
107
-
108
- return rag
109
-
110
-
111
- def main():
112
- rag = asyncio.run(initialize_rag())
113
-
114
- with open("./book.txt", "r", encoding="utf-8") as f:
115
- rag.insert(f.read())
116
-
117
- query_text = "What are the main themes?"
118
-
119
- print("Result (Naive):")
120
- print(rag.query(query_text, param=QueryParam(mode="naive")))
121
-
122
- print("\nResult (Local):")
123
- print(rag.query(query_text, param=QueryParam(mode="local")))
124
-
125
- print("\nResult (Global):")
126
- print(rag.query(query_text, param=QueryParam(mode="global")))
127
-
128
- print("\nResult (Hybrid):")
129
- print(rag.query(query_text, param=QueryParam(mode="hybrid")))
130
-
131
- print("\nResult (mix):")
132
- print(rag.query(query_text, param=QueryParam(mode="mix")))
133
-
134
-
135
- if __name__ == "__main__":
136
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
examples/lightrag_zhipu_demo.py DELETED
@@ -1,80 +0,0 @@
1
- import os
2
- import logging
3
- import asyncio
4
-
5
-
6
- from lightrag import LightRAG, QueryParam
7
- from lightrag.llm.zhipu import zhipu_complete, zhipu_embedding
8
- from lightrag.utils import EmbeddingFunc
9
- from lightrag.kg.shared_storage import initialize_pipeline_status
10
-
11
- WORKING_DIR = "./dickens"
12
-
13
- logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
14
-
15
- if not os.path.exists(WORKING_DIR):
16
- os.mkdir(WORKING_DIR)
17
-
18
- api_key = os.environ.get("ZHIPUAI_API_KEY")
19
- if api_key is None:
20
- raise Exception("Please set ZHIPU_API_KEY in your environment")
21
-
22
-
23
- async def initialize_rag():
24
- rag = LightRAG(
25
- working_dir=WORKING_DIR,
26
- llm_model_func=zhipu_complete,
27
- llm_model_name="glm-4-flashx", # Using the most cost/performance balance model, but you can change it here.
28
- llm_model_max_async=4,
29
- llm_model_max_token_size=32768,
30
- embedding_func=EmbeddingFunc(
31
- embedding_dim=2048, # Zhipu embedding-3 dimension
32
- max_token_size=8192,
33
- func=lambda texts: zhipu_embedding(texts),
34
- ),
35
- )
36
-
37
- await rag.initialize_storages()
38
- await initialize_pipeline_status()
39
-
40
- return rag
41
-
42
-
43
- def main():
44
- # Initialize RAG instance
45
- rag = asyncio.run(initialize_rag())
46
-
47
- with open("./book.txt", "r", encoding="utf-8") as f:
48
- rag.insert(f.read())
49
-
50
- # Perform naive search
51
- print(
52
- rag.query(
53
- "What are the top themes in this story?", param=QueryParam(mode="naive")
54
- )
55
- )
56
-
57
- # Perform local search
58
- print(
59
- rag.query(
60
- "What are the top themes in this story?", param=QueryParam(mode="local")
61
- )
62
- )
63
-
64
- # Perform global search
65
- print(
66
- rag.query(
67
- "What are the top themes in this story?", param=QueryParam(mode="global")
68
- )
69
- )
70
-
71
- # Perform hybrid search
72
- print(
73
- rag.query(
74
- "What are the top themes in this story?", param=QueryParam(mode="hybrid")
75
- )
76
- )
77
-
78
-
79
- if __name__ == "__main__":
80
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
examples/lightrag_zhipu_postgres_demo.py DELETED
@@ -1,109 +0,0 @@
1
- import asyncio
2
- import logging
3
- import os
4
- import time
5
- from dotenv import load_dotenv
6
-
7
- from lightrag import LightRAG, QueryParam
8
- from lightrag.llm.zhipu import zhipu_complete
9
- from lightrag.llm.ollama import ollama_embedding
10
- from lightrag.utils import EmbeddingFunc
11
- from lightrag.kg.shared_storage import initialize_pipeline_status
12
-
13
- load_dotenv()
14
- ROOT_DIR = os.environ.get("ROOT_DIR")
15
- WORKING_DIR = f"{ROOT_DIR}/dickens-pg"
16
-
17
- logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
18
-
19
- if not os.path.exists(WORKING_DIR):
20
- os.mkdir(WORKING_DIR)
21
-
22
- # AGE
23
- os.environ["AGE_GRAPH_NAME"] = "dickens"
24
-
25
- os.environ["POSTGRES_HOST"] = "localhost"
26
- os.environ["POSTGRES_PORT"] = "15432"
27
- os.environ["POSTGRES_USER"] = "rag"
28
- os.environ["POSTGRES_PASSWORD"] = "rag"
29
- os.environ["POSTGRES_DATABASE"] = "rag"
30
-
31
-
32
- async def initialize_rag():
33
- rag = LightRAG(
34
- working_dir=WORKING_DIR,
35
- llm_model_func=zhipu_complete,
36
- llm_model_name="glm-4-flashx",
37
- llm_model_max_async=4,
38
- llm_model_max_token_size=32768,
39
- enable_llm_cache_for_entity_extract=True,
40
- embedding_func=EmbeddingFunc(
41
- embedding_dim=1024,
42
- max_token_size=8192,
43
- func=lambda texts: ollama_embedding(
44
- texts, embed_model="bge-m3", host="http://localhost:11434"
45
- ),
46
- ),
47
- kv_storage="PGKVStorage",
48
- doc_status_storage="PGDocStatusStorage",
49
- graph_storage="PGGraphStorage",
50
- vector_storage="PGVectorStorage",
51
- auto_manage_storages_states=False,
52
- )
53
-
54
- await rag.initialize_storages()
55
- await initialize_pipeline_status()
56
-
57
- return rag
58
-
59
-
60
- async def main():
61
- # Initialize RAG instance
62
- rag = await initialize_rag()
63
-
64
- # add embedding_func for graph database, it's deleted in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c
65
- rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func
66
-
67
- with open(f"{ROOT_DIR}/book.txt", "r", encoding="utf-8") as f:
68
- await rag.ainsert(f.read())
69
-
70
- print("==== Trying to test the rag queries ====")
71
- print("**** Start Naive Query ****")
72
- start_time = time.time()
73
- # Perform naive search
74
- print(
75
- await rag.aquery(
76
- "What are the top themes in this story?", param=QueryParam(mode="naive")
77
- )
78
- )
79
- print(f"Naive Query Time: {time.time() - start_time} seconds")
80
- # Perform local search
81
- print("**** Start Local Query ****")
82
- start_time = time.time()
83
- print(
84
- await rag.aquery(
85
- "What are the top themes in this story?", param=QueryParam(mode="local")
86
- )
87
- )
88
- print(f"Local Query Time: {time.time() - start_time} seconds")
89
- # Perform global search
90
- print("**** Start Global Query ****")
91
- start_time = time.time()
92
- print(
93
- await rag.aquery(
94
- "What are the top themes in this story?", param=QueryParam(mode="global")
95
- )
96
- )
97
- print(f"Global Query Time: {time.time() - start_time}")
98
- # Perform hybrid search
99
- print("**** Start Hybrid Query ****")
100
- print(
101
- await rag.aquery(
102
- "What are the top themes in this story?", param=QueryParam(mode="hybrid")
103
- )
104
- )
105
- print(f"Hybrid Query Time: {time.time() - start_time} seconds")
106
-
107
-
108
- if __name__ == "__main__":
109
- asyncio.run(main())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lightrag/llm/ollama.py CHANGED
@@ -31,6 +31,7 @@ from lightrag.api import __api_version__
31
 
32
  import numpy as np
33
  from typing import Union
 
34
 
35
 
36
  @retry(
@@ -52,7 +53,7 @@ async def _ollama_model_if_cache(
52
  kwargs.pop("max_tokens", None)
53
  # kwargs.pop("response_format", None) # allow json
54
  host = kwargs.pop("host", None)
55
- timeout = kwargs.pop("timeout", None)
56
  kwargs.pop("hashing_kv", None)
57
  api_key = kwargs.pop("api_key", None)
58
  headers = {
@@ -61,32 +62,65 @@ async def _ollama_model_if_cache(
61
  }
62
  if api_key:
63
  headers["Authorization"] = f"Bearer {api_key}"
64
- ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
65
- messages = []
66
- if system_prompt:
67
- messages.append({"role": "system", "content": system_prompt})
68
- messages.extend(history_messages)
69
- messages.append({"role": "user", "content": prompt})
70
-
71
- response = await ollama_client.chat(model=model, messages=messages, **kwargs)
72
- if stream:
73
- """cannot cache stream response and process reasoning"""
74
-
75
- async def inner():
76
- async for chunk in response:
77
- yield chunk["message"]["content"]
78
-
79
- return inner()
80
- else:
81
- model_response = response["message"]["content"]
82
 
83
- """
84
- If the model also wraps its thoughts in a specific tag,
85
- this information is not needed for the final
86
- response and can simply be trimmed.
87
- """
88
 
89
- return model_response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
 
92
  async def ollama_model_complete(
@@ -105,19 +139,6 @@ async def ollama_model_complete(
105
  )
106
 
107
 
108
- async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
109
- """
110
- Deprecated in favor of `embed`.
111
- """
112
- embed_text = []
113
- ollama_client = ollama.Client(**kwargs)
114
- for text in texts:
115
- data = ollama_client.embeddings(model=embed_model, prompt=text)
116
- embed_text.append(data["embedding"])
117
-
118
- return embed_text
119
-
120
-
121
  async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
122
  api_key = kwargs.pop("api_key", None)
123
  headers = {
@@ -125,8 +146,29 @@ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
125
  "User-Agent": f"LightRAG/{__api_version__}",
126
  }
127
  if api_key:
128
- headers["Authorization"] = api_key
129
- kwargs["headers"] = headers
130
- ollama_client = ollama.Client(**kwargs)
131
- data = ollama_client.embed(model=embed_model, input=texts)
132
- return np.array(data["embeddings"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  import numpy as np
33
  from typing import Union
34
+ from lightrag.utils import logger
35
 
36
 
37
  @retry(
 
53
  kwargs.pop("max_tokens", None)
54
  # kwargs.pop("response_format", None) # allow json
55
  host = kwargs.pop("host", None)
56
+ timeout = kwargs.pop("timeout", None) or 300 # Default timeout 300s
57
  kwargs.pop("hashing_kv", None)
58
  api_key = kwargs.pop("api_key", None)
59
  headers = {
 
62
  }
63
  if api_key:
64
  headers["Authorization"] = f"Bearer {api_key}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
+ ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
 
 
 
 
67
 
68
+ try:
69
+ messages = []
70
+ if system_prompt:
71
+ messages.append({"role": "system", "content": system_prompt})
72
+ messages.extend(history_messages)
73
+ messages.append({"role": "user", "content": prompt})
74
+
75
+ response = await ollama_client.chat(model=model, messages=messages, **kwargs)
76
+ if stream:
77
+ """cannot cache stream response and process reasoning"""
78
+
79
+ async def inner():
80
+ try:
81
+ async for chunk in response:
82
+ yield chunk["message"]["content"]
83
+ except Exception as e:
84
+ logger.error(f"Error in stream response: {str(e)}")
85
+ raise
86
+ finally:
87
+ try:
88
+ await ollama_client._client.aclose()
89
+ logger.debug("Successfully closed Ollama client for streaming")
90
+ except Exception as close_error:
91
+ logger.warning(f"Failed to close Ollama client: {close_error}")
92
+
93
+ return inner()
94
+ else:
95
+ model_response = response["message"]["content"]
96
+
97
+ """
98
+ If the model also wraps its thoughts in a specific tag,
99
+ this information is not needed for the final
100
+ response and can simply be trimmed.
101
+ """
102
+
103
+ return model_response
104
+ except Exception as e:
105
+ try:
106
+ await ollama_client._client.aclose()
107
+ logger.debug("Successfully closed Ollama client after exception")
108
+ except Exception as close_error:
109
+ logger.warning(
110
+ f"Failed to close Ollama client after exception: {close_error}"
111
+ )
112
+ raise e
113
+ finally:
114
+ if not stream:
115
+ try:
116
+ await ollama_client._client.aclose()
117
+ logger.debug(
118
+ "Successfully closed Ollama client for non-streaming response"
119
+ )
120
+ except Exception as close_error:
121
+ logger.warning(
122
+ f"Failed to close Ollama client in finally block: {close_error}"
123
+ )
124
 
125
 
126
  async def ollama_model_complete(
 
139
  )
140
 
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
143
  api_key = kwargs.pop("api_key", None)
144
  headers = {
 
146
  "User-Agent": f"LightRAG/{__api_version__}",
147
  }
148
  if api_key:
149
+ headers["Authorization"] = f"Bearer {api_key}"
150
+
151
+ host = kwargs.pop("host", None)
152
+ timeout = kwargs.pop("timeout", None) or 90 # Default time out 90s
153
+
154
+ ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
155
+
156
+ try:
157
+ data = await ollama_client.embed(model=embed_model, input=texts)
158
+ return np.array(data["embeddings"])
159
+ except Exception as e:
160
+ logger.error(f"Error in ollama_embed: {str(e)}")
161
+ try:
162
+ await ollama_client._client.aclose()
163
+ logger.debug("Successfully closed Ollama client after exception in embed")
164
+ except Exception as close_error:
165
+ logger.warning(
166
+ f"Failed to close Ollama client after exception in embed: {close_error}"
167
+ )
168
+ raise e
169
+ finally:
170
+ try:
171
+ await ollama_client._client.aclose()
172
+ logger.debug("Successfully closed Ollama client after embed")
173
+ except Exception as close_error:
174
+ logger.warning(f"Failed to close Ollama client after embed: {close_error}")