gzdaniel committed
Commit f2f0d26 · Parent: 2cab749

Optimize Ollama LLM driver

Files changed (3):
  1. README-zh.md +1 -1
  2. README.md +1 -1
  3. lightrag/llm/ollama.py +78 -44

README-zh.md CHANGED
@@ -415,7 +415,7 @@ rag = LightRAG(
     embedding_func=EmbeddingFunc(
         embedding_dim=768,
         max_token_size=8192,
-        func=lambda texts: ollama_embedding(
+        func=lambda texts: ollama_embed(
             texts,
             embed_model="nomic-embed-text"
         )

README.md CHANGED
@@ -447,7 +447,7 @@ rag = LightRAG(
     embedding_func=EmbeddingFunc(
         embedding_dim=768,
         max_token_size=8192,
-        func=lambda texts: ollama_embedding(
+        func=lambda texts: ollama_embed(
             texts,
             embed_model="nomic-embed-text"
         )
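
Both README snippets switch the embedding callback from the removed `ollama_embedding` helper to `ollama_embed`. A minimal sketch of the updated setup; the `embedding_dim`, `max_token_size`, and model name come from the diff, while the imports and `working_dir` are assumed from other LightRAG examples:

```python
from lightrag import LightRAG
from lightrag.llm.ollama import ollama_embed
from lightrag.utils import EmbeddingFunc

rag = LightRAG(
    working_dir="./rag_storage",  # assumed; not part of this diff
    embedding_func=EmbeddingFunc(
        embedding_dim=768,
        max_token_size=8192,
        # The renamed batch helper from lightrag/llm/ollama.py:
        func=lambda texts: ollama_embed(texts, embed_model="nomic-embed-text"),
    ),
)
```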

lightrag/llm/ollama.py CHANGED
@@ -31,6 +31,7 @@ from lightrag.api import __api_version__
 
 import numpy as np
 from typing import Union
+from lightrag.utils import logger
 
 
 @retry(

@@ -52,7 +53,7 @@ async def _ollama_model_if_cache(
     kwargs.pop("max_tokens", None)
     # kwargs.pop("response_format", None) # allow json
     host = kwargs.pop("host", None)
-    timeout = kwargs.pop("timeout", None)
+    timeout = kwargs.pop("timeout", None) or 300  # Default timeout 300s
     kwargs.pop("hashing_kv", None)
     api_key = kwargs.pop("api_key", None)
     headers = {
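
One behavioral note on the `or 300` idiom used here: it treats every falsy timeout as unset, so an explicit `timeout=0` (sometimes used to mean "no timeout") is also replaced by the default:

```python
# How `kwargs.pop("timeout", None) or 300` resolves:
print(None or 300)  # 300 -> no timeout supplied
print(0 or 300)     # 300 -> an explicit 0 is falsy and gets overridden too
print(120 or 300)   # 120 -> a caller-supplied positive value wins
```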

@@ -61,32 +62,59 @@ async def _ollama_model_if_cache(
     }
     if api_key:
         headers["Authorization"] = f"Bearer {api_key}"
+
     ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
-    messages = []
-    if system_prompt:
-        messages.append({"role": "system", "content": system_prompt})
-    messages.extend(history_messages)
-    messages.append({"role": "user", "content": prompt})
-
-    response = await ollama_client.chat(model=model, messages=messages, **kwargs)
-    if stream:
-        """cannot cache stream response and process reasoning"""
-
-        async def inner():
-            async for chunk in response:
-                yield chunk["message"]["content"]
-
-        return inner()
-    else:
-        model_response = response["message"]["content"]
-
-        """
-        If the model also wraps its thoughts in a specific tag,
-        this information is not needed for the final
-        response and can simply be trimmed.
-        """
-
-        return model_response
+
+    try:
+        messages = []
+        if system_prompt:
+            messages.append({"role": "system", "content": system_prompt})
+        messages.extend(history_messages)
+        messages.append({"role": "user", "content": prompt})
+
+        response = await ollama_client.chat(model=model, messages=messages, **kwargs)
+        if stream:
+            """cannot cache stream response and process reasoning"""
+
+            async def inner():
+                try:
+                    async for chunk in response:
+                        yield chunk["message"]["content"]
+                except Exception as e:
+                    logger.error(f"Error in stream response: {str(e)}")
+                    raise
+                finally:
+                    try:
+                        await ollama_client._client.aclose()
+                        logger.debug("Successfully closed Ollama client for streaming")
+                    except Exception as close_error:
+                        logger.warning(f"Failed to close Ollama client: {close_error}")
+
+            return inner()
+        else:
+            model_response = response["message"]["content"]
+
+            """
+            If the model also wraps its thoughts in a specific tag,
+            this information is not needed for the final
+            response and can simply be trimmed.
+            """
+
+            return model_response
+    except Exception as e:
+        try:
+            await ollama_client._client.aclose()
+            logger.debug("Successfully closed Ollama client after exception")
+        except Exception as close_error:
+            logger.warning(f"Failed to close Ollama client after exception: {close_error}")
+        raise e
+    finally:
+        if not stream:
+            try:
+                await ollama_client._client.aclose()
+                logger.debug("Successfully closed Ollama client for non-streaming response")
+            except Exception as close_error:
+                logger.warning(f"Failed to close Ollama client in finally block: {close_error}")
 
 
 async def ollama_model_complete(
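
The key change in this hunk is that `inner()` now owns the client's lifetime: a `finally` inside an async generator runs when the stream is exhausted or the consumer closes it early, so the connection is released on every path. A self-contained sketch of that pattern (`fake_stream` is a stand-in for the Ollama chat stream, not the real API; reaching into `ollama_client._client` in the driver relies on a private attribute, so the sketch only mirrors the control flow):

```python
import asyncio

async def fake_stream():
    # Stand-in for `await ollama_client.chat(..., stream=True)`
    for piece in ("Hel", "lo"):
        yield {"message": {"content": piece}}

async def inner():
    try:
        async for chunk in fake_stream():
            yield chunk["message"]["content"]
    finally:
        # In the driver this is `await ollama_client._client.aclose()`
        print("\nclient closed")

async def main():
    async for text in inner():
        print(text, end="")

asyncio.run(main())
```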

@@ -105,19 +133,6 @@ async def ollama_model_complete(
     )
 
 
-async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
-    """
-    Deprecated in favor of `embed`.
-    """
-    embed_text = []
-    ollama_client = ollama.Client(**kwargs)
-    for text in texts:
-        data = ollama_client.embeddings(model=embed_model, prompt=text)
-        embed_text.append(data["embedding"])
-
-    return embed_text
-
-
 async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
     api_key = kwargs.pop("api_key", None)
     headers = {
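
For context on why the deprecated `ollama_embedding` was dropped rather than kept: it issued one `embeddings()` request per text, while `ollama_embed` sends the whole batch through a single `embed()` call. A sketch of the two call shapes as they appear in the diff (the host value is an assumed local default):

```python
import ollama

client = ollama.Client(host="http://localhost:11434")  # assumed default host
texts = ["first passage", "second passage"]

# Old pattern (deprecated ollama_embedding): N round trips for N texts
old_vectors = [
    client.embeddings(model="nomic-embed-text", prompt=t)["embedding"]
    for t in texts
]

# New pattern (ollama_embed): one batched request
new_vectors = client.embed(model="nomic-embed-text", input=texts)["embeddings"]
```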

@@ -125,8 +140,27 @@ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
         "User-Agent": f"LightRAG/{__api_version__}",
     }
     if api_key:
-        headers["Authorization"] = api_key
-    kwargs["headers"] = headers
-    ollama_client = ollama.Client(**kwargs)
-    data = ollama_client.embed(model=embed_model, input=texts)
-    return np.array(data["embeddings"])
+        headers["Authorization"] = f"Bearer {api_key}"
+
+    host = kwargs.pop("host", None)
+    timeout = kwargs.pop("timeout", None) or 90  # Default time out 90s
+
+    ollama_client = ollama.AsyncClient(host=host, timeout=timeout, headers=headers)
+
+    try:
+        data = await ollama_client.embed(model=embed_model, input=texts)
+        return np.array(data["embeddings"])
+    except Exception as e:
+        logger.error(f"Error in ollama_embed: {str(e)}")
+        try:
+            await ollama_client._client.aclose()
+            logger.debug("Successfully closed Ollama client after exception in embed")
+        except Exception as close_error:
+            logger.warning(f"Failed to close Ollama client after exception in embed: {close_error}")
+        raise e
+    finally:
+        try:
+            await ollama_client._client.aclose()
+            logger.debug("Successfully closed Ollama client after embed")
+        except Exception as close_error:
+            logger.warning(f"Failed to close Ollama client after embed: {close_error}")
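
Since `ollama_embed` now pops `host` and `timeout` from its kwargs, both can be tuned per call. A hedged usage sketch (the values are illustrative; only the 90 s fallback comes from the diff):

```python
import asyncio
from lightrag.llm.ollama import ollama_embed

async def main():
    vectors = await ollama_embed(
        ["some text to embed"],
        embed_model="nomic-embed-text",
        host="http://localhost:11434",  # assumed local Ollama endpoint
        timeout=30,  # overrides the 90 s default; 0 or None falls back to 90
    )
    print(vectors.shape)

asyncio.run(main())
```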