João Galego committed
Commit ee795f8 · 1 Parent(s): 3d203c4

Added support for Amazon Bedrock models

.gitignore ADDED
@@ -0,0 +1,4 @@
+ __pycache__
+ *.egg-info
+ dickens/
+ book.txt
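For context: __pycache__ and *.egg-info are standard Python build artifacts, while dickens/ and book.txt are the working directory and input text used by the Bedrock demo added below (examples/lightrag_bedrock_demo.py).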
examples/lightrag_bedrock_demo.py ADDED
@@ -0,0 +1,48 @@
+ """
+ LightRAG meets Amazon Bedrock ⛰️
+ """
+
+ import os
+
+ from lightrag import LightRAG, QueryParam
+ from lightrag.llm import bedrock_complete, bedrock_embedding
+ from lightrag.utils import EmbeddingFunc
+
+ WORKING_DIR = "./dickens"
+
+ if not os.path.exists(WORKING_DIR):
+     os.mkdir(WORKING_DIR)
+
+ rag = LightRAG(
+     working_dir=WORKING_DIR,
+     llm_model_func=bedrock_complete,
+     llm_model_name="anthropic.claude-3-haiku-20240307-v1:0",
+     node2vec_params={
+         'dimensions': 1024,
+         'num_walks': 10,
+         'walk_length': 40,
+         'window_size': 2,
+         'iterations': 3,
+         'random_seed': 3
+     },
+     embedding_func=EmbeddingFunc(
+         embedding_dim=1024,
+         max_token_size=8192,
+         func=bedrock_embedding
+     )
+ )
+
+ with open("./book.txt") as f:
+     rag.insert(f.read())
+
+ # Naive search
+ print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
+
+ # Local search
+ print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
+
+ # Global search
+ print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
+
+ # Hybrid search
+ print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
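Note that the demo never sets AWS credentials or a region explicitly; bedrock_complete and bedrock_embedding fall back to the default AWS credential chain. A minimal pre-flight sketch, assuming boto3 is importable (it is pulled in transitively by aioboto3):

import boto3

# Illustrative check, not part of the commit: confirm that the default
# AWS credential chain and a region resolve before running the demo.
session = boto3.Session()
assert session.get_credentials() is not None, "No AWS credentials found (env vars, ~/.aws, or IAM role)"
print("Region:", session.region_name or "unset -- export AWS_DEFAULT_REGION")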
lightrag/llm.py CHANGED
@@ -1,4 +1,6 @@
  import os
+ import json
+ import aioboto3
  import numpy as np
  import ollama
  from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
@@ -48,6 +50,54 @@ async def openai_complete_if_cache(
      )
      return response.choices[0].message.content

+ @retry(
+     stop=stop_after_attempt(3),
+     wait=wait_exponential(multiplier=1, min=4, max=10),
+     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+ )
+ async def bedrock_complete_if_cache(
+     model, prompt, system_prompt=None, history_messages=[], base_url=None,
+     aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None, **kwargs
+ ) -> str:
+     if aws_access_key_id: os.environ.setdefault('AWS_ACCESS_KEY_ID', aws_access_key_id)
+     if aws_secret_access_key: os.environ.setdefault('AWS_SECRET_ACCESS_KEY', aws_secret_access_key)
+     if aws_session_token: os.environ.setdefault('AWS_SESSION_TOKEN', aws_session_token)
+
+     hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
+
+     messages = []
+     messages.extend(history_messages)
+     messages.append({'role': "user", 'content': [{'text': prompt}]})
+
+     args = {
+         'modelId': model,
+         'messages': messages
+     }
+
+     if system_prompt:
+         args['system'] = [{'text': system_prompt}]
+
+     if hashing_kv is not None:
+         args_hash = compute_args_hash(model, messages)
+         if_cache_return = await hashing_kv.get_by_id(args_hash)
+         if if_cache_return is not None:
+             return if_cache_return["return"]
+
+     session = aioboto3.Session()
+     async with session.client("bedrock-runtime") as bedrock_async_client:
+
+         response = await bedrock_async_client.converse(**args, **kwargs)
+
+         if hashing_kv is not None:
+             await hashing_kv.upsert({
+                 args_hash: {
+                     'return': response['output']['message']['content'][0]['text'],
+                     'model': model
+                 }
+             })
+
+         return response['output']['message']['content'][0]['text']
+
  async def hf_model_if_cache(
      model, prompt, system_prompt=None, history_messages=[], **kwargs
  ) -> str:
@@ -145,6 +195,19 @@ async def gpt_4o_mini_complete(
          **kwargs,
      )

+
+ async def bedrock_complete(
+     prompt, system_prompt=None, history_messages=[], **kwargs
+ ) -> str:
+     return await bedrock_complete_if_cache(
+         "anthropic.claude-3-sonnet-20240229-v1:0",
+         prompt,
+         system_prompt=system_prompt,
+         history_messages=history_messages,
+         **kwargs,
+     )
+
+
  async def hf_model_complete(
      prompt, system_prompt=None, history_messages=[], **kwargs
  ) -> str:
@@ -186,6 +249,71 @@ async def openai_embedding(texts: list[str], model: str = "text-embedding-3-smal
      return np.array([dp.embedding for dp in response.data])


+ # @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
+ # @retry(
+ #     stop=stop_after_attempt(3),
+ #     wait=wait_exponential(multiplier=1, min=4, max=10),
+ #     retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),  # TODO: fix exceptions
+ # )
+ async def bedrock_embedding(
+     texts: list[str], model: str = "amazon.titan-embed-text-v2:0",
+     aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None) -> np.ndarray:
+     if aws_access_key_id: os.environ.setdefault('AWS_ACCESS_KEY_ID', aws_access_key_id)
+     if aws_secret_access_key: os.environ.setdefault('AWS_SECRET_ACCESS_KEY', aws_secret_access_key)
+     if aws_session_token: os.environ.setdefault('AWS_SESSION_TOKEN', aws_session_token)
+
+     session = aioboto3.Session()
+     async with session.client("bedrock-runtime") as bedrock_async_client:
+
+         if (model_provider := model.split(".")[0]) == "amazon":
+             embed_texts = []
+             for text in texts:
+                 if "v2" in model:
+                     body = json.dumps({
+                         'inputText': text,
+                         # 'dimensions': embedding_dim,
+                         'embeddingTypes': ["float"]
+                     })
+                 elif "v1" in model:
+                     body = json.dumps({
+                         'inputText': text
+                     })
+                 else:
+                     raise ValueError(f"Model {model} is not supported!")
+
+                 response = await bedrock_async_client.invoke_model(
+                     modelId=model,
+                     body=body,
+                     accept="application/json",
+                     contentType="application/json"
+                 )
+
+                 response_body = json.loads(await response.get('body').read())
+
+                 embed_texts.append(response_body['embedding'])
+         elif model_provider == "cohere":
+             body = json.dumps({
+                 'texts': texts,
+                 'input_type': "search_document",
+                 'truncate': "NONE"
+             })
+
+             response = await bedrock_async_client.invoke_model(
+                 modelId=model,
+                 body=body,
+                 accept="application/json",
+                 contentType="application/json"
+             )
+
+             response_body = json.loads(await response.get('body').read())
+
+             embed_texts = response_body['embeddings']
+         else:
+             raise ValueError(f"Model provider '{model_provider}' is not supported!")
+
+         return np.array(embed_texts)
+
+
  async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
      input_ids = tokenizer(texts, return_tensors='pt', padding=True, truncation=True).input_ids
      with torch.no_grad():
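For a quick sanity check of the new helpers outside of LightRAG, a minimal sketch like the following should work, assuming credentials with access to the default Bedrock models (Claude 3 Sonnet for completion, Titan Text Embeddings V2 for embeddings) are configured:

import asyncio

from lightrag.llm import bedrock_complete, bedrock_embedding

async def main():
    # Uses the default model baked into bedrock_complete above
    print(await bedrock_complete("In one sentence, what is retrieval-augmented generation?"))

    # Returns an (n_texts, embedding_dim) NumPy array
    vectors = await bedrock_embedding(["hello", "world"])
    print(vectors.shape)

asyncio.run(main())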
requirements.txt CHANGED
@@ -1,3 +1,4 @@
+ aioboto3
  openai
  tiktoken
  networkx