gzdaniel commited on
Commit
74a3756
·
1 Parent(s): f2f0d26

Update Ollama sample code

Browse files
Files changed (1) hide show
  1. examples/lightrag_ollama_demo.py +173 -58
examples/lightrag_ollama_demo.py CHANGED
@@ -1,19 +1,84 @@
1
  import asyncio
2
- import nest_asyncio
3
-
4
  import os
5
  import inspect
6
  import logging
 
7
  from lightrag import LightRAG, QueryParam
8
  from lightrag.llm.ollama import ollama_model_complete, ollama_embed
9
- from lightrag.utils import EmbeddingFunc
10
  from lightrag.kg.shared_storage import initialize_pipeline_status
11
 
12
- nest_asyncio.apply()
 
 
13
 
14
  WORKING_DIR = "./dickens"
15
 
16
- logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  if not os.path.exists(WORKING_DIR):
19
  os.mkdir(WORKING_DIR)
@@ -23,18 +88,20 @@ async def initialize_rag():
23
  rag = LightRAG(
24
  working_dir=WORKING_DIR,
25
  llm_model_func=ollama_model_complete,
26
- llm_model_name="gemma2:2b",
27
- llm_model_max_async=4,
28
- llm_model_max_token_size=32768,
29
  llm_model_kwargs={
30
- "host": "http://localhost:11434",
31
- "options": {"num_ctx": 32768},
 
32
  },
33
  embedding_func=EmbeddingFunc(
34
- embedding_dim=768,
35
- max_token_size=8192,
36
  func=lambda texts: ollama_embed(
37
- texts, embed_model="nomic-embed-text", host="http://localhost:11434"
 
 
38
  ),
39
  ),
40
  )
@@ -50,54 +117,102 @@ async def print_stream(stream):
50
  print(chunk, end="", flush=True)
51
 
52
 
53
- def main():
54
- # Initialize RAG instance
55
- rag = asyncio.run(initialize_rag())
56
-
57
- # Insert example text
58
- with open("./book.txt", "r", encoding="utf-8") as f:
59
- rag.insert(f.read())
60
-
61
- # Test different query modes
62
- print("\nNaive Search:")
63
- print(
64
- rag.query(
65
- "What are the top themes in this story?", param=QueryParam(mode="naive")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  )
67
- )
68
-
69
- print("\nLocal Search:")
70
- print(
71
- rag.query(
72
- "What are the top themes in this story?", param=QueryParam(mode="local")
 
 
 
 
 
 
73
  )
74
- )
75
-
76
- print("\nGlobal Search:")
77
- print(
78
- rag.query(
79
- "What are the top themes in this story?", param=QueryParam(mode="global")
 
 
 
 
 
 
80
  )
81
- )
82
-
83
- print("\nHybrid Search:")
84
- print(
85
- rag.query(
86
- "What are the top themes in this story?", param=QueryParam(mode="hybrid")
 
 
 
 
 
 
87
  )
88
- )
89
-
90
- # stream response
91
- resp = rag.query(
92
- "What are the top themes in this story?",
93
- param=QueryParam(mode="hybrid", stream=True),
94
- )
95
-
96
- if inspect.isasyncgen(resp):
97
- asyncio.run(print_stream(resp))
98
- else:
99
- print(resp)
100
-
101
 
102
  if __name__ == "__main__":
103
- main()
 
 
 
 
1
  import asyncio
 
 
2
  import os
3
  import inspect
4
  import logging
5
+ import logging.config
6
  from lightrag import LightRAG, QueryParam
7
  from lightrag.llm.ollama import ollama_model_complete, ollama_embed
8
+ from lightrag.utils import EmbeddingFunc, logger, set_verbose_debug
9
  from lightrag.kg.shared_storage import initialize_pipeline_status
10
 
11
+ from dotenv import load_dotenv
12
+
13
+ load_dotenv(dotenv_path=".env", override=False)
14
 
15
  WORKING_DIR = "./dickens"
16
 
17
+
18
+ def configure_logging():
19
+ """Configure logging for the application"""
20
+
21
+ # Reset any existing handlers to ensure clean configuration
22
+ for logger_name in ["uvicorn", "uvicorn.access", "uvicorn.error", "lightrag"]:
23
+ logger_instance = logging.getLogger(logger_name)
24
+ logger_instance.handlers = []
25
+ logger_instance.filters = []
26
+
27
+ # Get log directory path from environment variable or use current directory
28
+ log_dir = os.getenv("LOG_DIR", os.getcwd())
29
+ log_file_path = os.path.abspath(
30
+ os.path.join(log_dir, "lightrag_ollama_demo.log")
31
+ )
32
+
33
+ print(f"\nLightRAG compatible demo log file: {log_file_path}\n")
34
+ os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
35
+
36
+ # Get log file max size and backup count from environment variables
37
+ log_max_bytes = int(os.getenv("LOG_MAX_BYTES", 10485760)) # Default 10MB
38
+ log_backup_count = int(os.getenv("LOG_BACKUP_COUNT", 5)) # Default 5 backups
39
+
40
+ logging.config.dictConfig(
41
+ {
42
+ "version": 1,
43
+ "disable_existing_loggers": False,
44
+ "formatters": {
45
+ "default": {
46
+ "format": "%(levelname)s: %(message)s",
47
+ },
48
+ "detailed": {
49
+ "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
50
+ },
51
+ },
52
+ "handlers": {
53
+ "console": {
54
+ "formatter": "default",
55
+ "class": "logging.StreamHandler",
56
+ "stream": "ext://sys.stderr",
57
+ },
58
+ "file": {
59
+ "formatter": "detailed",
60
+ "class": "logging.handlers.RotatingFileHandler",
61
+ "filename": log_file_path,
62
+ "maxBytes": log_max_bytes,
63
+ "backupCount": log_backup_count,
64
+ "encoding": "utf-8",
65
+ },
66
+ },
67
+ "loggers": {
68
+ "lightrag": {
69
+ "handlers": ["console", "file"],
70
+ "level": "INFO",
71
+ "propagate": False,
72
+ },
73
+ },
74
+ }
75
+ )
76
+
77
+ # Set the logger level to INFO
78
+ logger.setLevel(logging.INFO)
79
+ # Enable verbose debug if needed
80
+ set_verbose_debug(os.getenv("VERBOSE_DEBUG", "false").lower() == "true")
81
+
82
 
83
  if not os.path.exists(WORKING_DIR):
84
  os.mkdir(WORKING_DIR)
 
88
  rag = LightRAG(
89
  working_dir=WORKING_DIR,
90
  llm_model_func=ollama_model_complete,
91
+ llm_model_name=os.getenv("LLM_MODEL", "qwen2.5-coder:7b"),
92
+ llm_model_max_token_size=8192,
 
93
  llm_model_kwargs={
94
+ "host": os.getenv("LLM_BINDING_HOST", "http://localhost:11434"),
95
+ "options": {"num_ctx": 8192},
96
+ "timeout": int(os.getenv("TIMEOUT", "300")),
97
  },
98
  embedding_func=EmbeddingFunc(
99
+ embedding_dim=int(os.getenv("EMBEDDING_DIM", "1024")),
100
+ max_token_size=int(os.getenv("MAX_EMBED_TOKENS", "8192")),
101
  func=lambda texts: ollama_embed(
102
+ texts,
103
+ embed_model=os.getenv("EMBEDDING_MODEL", "bge-m3:latest"),
104
+ host=os.getenv("EMBEDDING_BINDING_HOST", "http://localhost:11434"),
105
  ),
106
  ),
107
  )
 
117
  print(chunk, end="", flush=True)
118
 
119
 
120
+ async def main():
121
+ try:
122
+ # Clear old data files
123
+ files_to_delete = [
124
+ "graph_chunk_entity_relation.graphml",
125
+ "kv_store_doc_status.json",
126
+ "kv_store_full_docs.json",
127
+ "kv_store_text_chunks.json",
128
+ "vdb_chunks.json",
129
+ "vdb_entities.json",
130
+ "vdb_relationships.json",
131
+ ]
132
+
133
+ for file in files_to_delete:
134
+ file_path = os.path.join(WORKING_DIR, file)
135
+ if os.path.exists(file_path):
136
+ os.remove(file_path)
137
+ print(f"Deleting old file:: {file_path}")
138
+
139
+ # Initialize RAG instance
140
+ rag = await initialize_rag()
141
+
142
+ # Test embedding function
143
+ test_text = ["This is a test string for embedding."]
144
+ embedding = await rag.embedding_func(test_text)
145
+ embedding_dim = embedding.shape[1]
146
+ print("\n=======================")
147
+ print("Test embedding function")
148
+ print("========================")
149
+ print(f"Test dict: {test_text}")
150
+ print(f"Detected embedding dimension: {embedding_dim}\n\n")
151
+
152
+ with open("./book.txt", "r", encoding="utf-8") as f:
153
+ await rag.ainsert(f.read())
154
+
155
+ # Perform naive search
156
+ print("\n=====================")
157
+ print("Query mode: naive")
158
+ print("=====================")
159
+ resp = await rag.aquery(
160
+ "What are the top themes in this story?",
161
+ param=QueryParam(mode="naive", stream=True),
162
  )
163
+ if inspect.isasyncgen(resp):
164
+ await print_stream(resp)
165
+ else:
166
+ print(resp)
167
+
168
+ # Perform local search
169
+ print("\n=====================")
170
+ print("Query mode: local")
171
+ print("=====================")
172
+ resp = await rag.aquery(
173
+ "What are the top themes in this story?",
174
+ param=QueryParam(mode="local", stream=True),
175
  )
176
+ if inspect.isasyncgen(resp):
177
+ await print_stream(resp)
178
+ else:
179
+ print(resp)
180
+
181
+ # Perform global search
182
+ print("\n=====================")
183
+ print("Query mode: global")
184
+ print("=====================")
185
+ resp = await rag.aquery(
186
+ "What are the top themes in this story?",
187
+ param=QueryParam(mode="global", stream=True),
188
  )
189
+ if inspect.isasyncgen(resp):
190
+ await print_stream(resp)
191
+ else:
192
+ print(resp)
193
+
194
+ # Perform hybrid search
195
+ print("\n=====================")
196
+ print("Query mode: hybrid")
197
+ print("=====================")
198
+ resp = await rag.aquery(
199
+ "What are the top themes in this story?",
200
+ param=QueryParam(mode="hybrid", stream=True),
201
  )
202
+ if inspect.isasyncgen(resp):
203
+ await print_stream(resp)
204
+ else:
205
+ print(resp)
206
+
207
+ except Exception as e:
208
+ print(f"An error occurred: {e}")
209
+ finally:
210
+ if rag:
211
+ await rag.llm_response_cache.index_done_callback()
212
+ await rag.finalize_storages()
 
 
213
 
214
  if __name__ == "__main__":
215
+ # Configure logging before running the main function
216
+ configure_logging()
217
+ asyncio.run(main())
218
+ print("\nDone!")