LarFii commited on
Commit
251e443
·
1 Parent(s): 99496e0

ollama test

Browse files
examples/lightrag_ollama_demo.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from lightrag import LightRAG, QueryParam
4
+ from lightrag.llm import ollama_model_complete, ollama_embedding
5
+ from lightrag.utils import EmbeddingFunc
6
+
7
+ WORKING_DIR = "./dickens"
8
+
9
+ if not os.path.exists(WORKING_DIR):
10
+ os.mkdir(WORKING_DIR)
11
+
12
+ rag = LightRAG(
13
+ working_dir=WORKING_DIR,
14
+ llm_model_func=ollama_model_complete,
15
+ llm_model_name='your_model_name',
16
+ embedding_func=EmbeddingFunc(
17
+ embedding_dim=768,
18
+ max_token_size=8192,
19
+ func=lambda texts: ollama_embedding(
20
+ texts,
21
+ embed_model="nomic-embed-text"
22
+ )
23
+ ),
24
+ )
25
+
26
+
27
+ with open("./book.txt") as f:
28
+ rag.insert(f.read())
29
+
30
+ # Perform naive search
31
+ print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
32
+
33
+ # Perform local search
34
+ print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
35
+
36
+ # Perform global search
37
+ print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
38
+
39
+ # Perform hybrid search
40
+ print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
lightrag/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
  from .lightrag import LightRAG, QueryParam
2
 
3
- __version__ = "0.0.5"
4
  __author__ = "Zirui Guo"
5
  __url__ = "https://github.com/HKUDS/LightRAG"
 
1
  from .lightrag import LightRAG, QueryParam
2
 
3
+ __version__ = "0.0.6"
4
  __author__ = "Zirui Guo"
5
  __url__ = "https://github.com/HKUDS/LightRAG"
lightrag/lightrag.py CHANGED
@@ -6,7 +6,7 @@ from functools import partial
6
  from typing import Type, cast, Any
7
  from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
8
 
9
- from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding,hf_model_complete,hf_embedding
10
  from .operate import (
11
  chunking_by_token_size,
12
  extract_entities,
 
6
  from typing import Type, cast, Any
7
  from transformers import AutoModel,AutoTokenizer, AutoModelForCausalLM
8
 
9
+ from .llm import gpt_4o_complete, gpt_4o_mini_complete, openai_embedding, hf_model_complete, hf_embedding
10
  from .operate import (
11
  chunking_by_token_size,
12
  extract_entities,
lightrag/llm.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import numpy as np
 
3
  from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
4
  from tenacity import (
5
  retry,
@@ -92,6 +93,34 @@ async def hf_model_if_cache(
92
  )
93
  return response_text
94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
  async def gpt_4o_complete(
97
  prompt, system_prompt=None, history_messages=[], **kwargs
@@ -116,8 +145,6 @@ async def gpt_4o_mini_complete(
116
  **kwargs,
117
  )
118
 
119
-
120
-
121
  async def hf_model_complete(
122
  prompt, system_prompt=None, history_messages=[], **kwargs
123
  ) -> str:
@@ -130,6 +157,18 @@ async def hf_model_complete(
130
  **kwargs,
131
  )
132
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
134
  @retry(
135
  stop=stop_after_attempt(3),
@@ -154,6 +193,13 @@ async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
154
  embeddings = outputs.last_hidden_state.mean(dim=1)
155
  return embeddings.detach().numpy()
156
 
 
 
 
 
 
 
 
157
 
158
  if __name__ == "__main__":
159
  import asyncio
 
1
  import os
2
  import numpy as np
3
+ import ollama
4
  from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
5
  from tenacity import (
6
  retry,
 
93
  )
94
  return response_text
95
 
96
+ async def ollama_model_if_cache(
97
+ model, prompt, system_prompt=None, history_messages=[], **kwargs
98
+ ) -> str:
99
+ kwargs.pop("max_tokens", None)
100
+ kwargs.pop("response_format", None)
101
+
102
+ ollama_client = ollama.AsyncClient()
103
+ messages = []
104
+ if system_prompt:
105
+ messages.append({"role": "system", "content": system_prompt})
106
+
107
+ hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
108
+ messages.extend(history_messages)
109
+ messages.append({"role": "user", "content": prompt})
110
+ if hashing_kv is not None:
111
+ args_hash = compute_args_hash(model, messages)
112
+ if_cache_return = await hashing_kv.get_by_id(args_hash)
113
+ if if_cache_return is not None:
114
+ return if_cache_return["return"]
115
+
116
+ response = await ollama_client.chat(model=model, messages=messages, **kwargs)
117
+
118
+ result = response["message"]["content"]
119
+
120
+ if hashing_kv is not None:
121
+ await hashing_kv.upsert({args_hash: {"return": result, "model": model}})
122
+
123
+ return result
124
 
125
  async def gpt_4o_complete(
126
  prompt, system_prompt=None, history_messages=[], **kwargs
 
145
  **kwargs,
146
  )
147
 
 
 
148
  async def hf_model_complete(
149
  prompt, system_prompt=None, history_messages=[], **kwargs
150
  ) -> str:
 
157
  **kwargs,
158
  )
159
 
160
+ async def ollama_model_complete(
161
+ prompt, system_prompt=None, history_messages=[], **kwargs
162
+ ) -> str:
163
+ model_name = kwargs['hashing_kv'].global_config['llm_model_name']
164
+ return await ollama_model_if_cache(
165
+ model_name,
166
+ prompt,
167
+ system_prompt=system_prompt,
168
+ history_messages=history_messages,
169
+ **kwargs,
170
+ )
171
+
172
  @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
173
  @retry(
174
  stop=stop_after_attempt(3),
 
193
  embeddings = outputs.last_hidden_state.mean(dim=1)
194
  return embeddings.detach().numpy()
195
 
196
+ async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
197
+ embed_text = []
198
+ for text in texts:
199
+ data = ollama.embeddings(model=embed_model, prompt=text)
200
+ embed_text.append(data["embedding"])
201
+
202
+ return embed_text
203
 
204
  if __name__ == "__main__":
205
  import asyncio
requirements.txt CHANGED
@@ -6,3 +6,6 @@ nano-vectordb
6
  hnswlib
7
  xxhash
8
  tenacity
 
 
 
 
6
  hnswlib
7
  xxhash
8
  tenacity
9
+ transformers
10
+ torch
11
+ ollama
setup.py CHANGED
@@ -1,6 +1,6 @@
1
  import setuptools
2
 
3
- with open("README.md", "r") as fh:
4
  long_description = fh.read()
5
 
6
 
 
1
  import setuptools
2
 
3
+ with open("README.md", "r", encoding="utf-8") as fh:
4
  long_description = fh.read()
5
 
6