João Galego committed on
Commit 51d8af2 · 1 Parent(s): ee795f8

Fixed retry strategy, message history, and inference params; cleaned up Bedrock example

Files changed (2):
  1. examples/lightrag_bedrock_demo.py +16 -23
  2. lightrag/llm.py +39 -9
examples/lightrag_bedrock_demo.py CHANGED
@@ -3,46 +3,39 @@ LightRAG meets Amazon Bedrock ⛰️
 """
 
 import os
+import logging
 
 from lightrag import LightRAG, QueryParam
 from lightrag.llm import bedrock_complete, bedrock_embedding
 from lightrag.utils import EmbeddingFunc
 
-WORKING_DIR = "./dickens"
+logging.getLogger("aiobotocore").setLevel(logging.WARNING)
 
+WORKING_DIR = "./dickens"
 if not os.path.exists(WORKING_DIR):
     os.mkdir(WORKING_DIR)
 
 rag = LightRAG(
     working_dir=WORKING_DIR,
     llm_model_func=bedrock_complete,
-    llm_model_name="anthropic.claude-3-haiku-20240307-v1:0",
-    node2vec_params = {
-        'dimensions': 1024,
-        'num_walks': 10,
-        'walk_length': 40,
-        'window_size': 2,
-        'iterations': 3,
-        'random_seed': 3
-    },
+    llm_model_name="Anthropic Claude 3 Haiku // Amazon Bedrock",
     embedding_func=EmbeddingFunc(
         embedding_dim=1024,
         max_token_size=8192,
-        func=lambda texts: bedrock_embedding(texts)
+        func=bedrock_embedding
     )
 )
 
-with open("./book.txt") as f:
+with open("./book.txt", 'r', encoding='utf-8') as f:
     rag.insert(f.read())
 
-# Naive search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
-
-# Local search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
-
-# Global search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
-
-# Hybrid search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
+for mode in ["naive", "local", "global", "hybrid"]:
+    print("\n+-" + "-" * len(mode) + "-+")
+    print(f"| {mode.capitalize()} |")
+    print("+-" + "-" * len(mode) + "-+\n")
+    print(
+        rag.query(
+            "What are the top themes in this story?",
+            param=QueryParam(mode=mode)
+        )
+    )
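Note: the new query loop sizes its banner to the mode name, so each answer is easy to spot in the console. For example, with mode="naive" the three print calls produce:

+-------+
| Naive |
+-------+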
lightrag/llm.py CHANGED
@@ -1,6 +1,9 @@
 import os
+import copy
 import json
+import botocore
 import aioboto3
+import botocore.errorfactory
 import numpy as np
 import ollama
 from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout
@@ -50,43 +53,70 @@ async def openai_complete_if_cache(
     )
     return response.choices[0].message.content
 
+
+class BedrockError(Exception):
+    """Generic error for issues related to Amazon Bedrock"""
+
+
 @retry(
-    stop=stop_after_attempt(3),
-    wait=wait_exponential(multiplier=1, min=4, max=10),
-    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+    stop=stop_after_attempt(5),
+    wait=wait_exponential(multiplier=1, max=60),
+    retry=retry_if_exception_type((BedrockError)),
 )
 async def bedrock_complete_if_cache(
-    model, prompt, system_prompt=None, history_messages=[], base_url=None,
+    model, prompt, system_prompt=None, history_messages=[],
     aws_access_key_id=None, aws_secret_access_key=None, aws_session_token=None, **kwargs
 ) -> str:
     os.environ['AWS_ACCESS_KEY_ID'] = os.environ.get('AWS_ACCESS_KEY_ID', aws_access_key_id)
     os.environ['AWS_SECRET_ACCESS_KEY'] = os.environ.get('AWS_SECRET_ACCESS_KEY', aws_secret_access_key)
     os.environ['AWS_SESSION_TOKEN'] = os.environ.get('AWS_SESSION_TOKEN', aws_session_token)
 
-    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
-
+    # Fix message history format
     messages = []
-    messages.extend(history_messages)
+    for history_message in history_messages:
+        message = copy.copy(history_message)
+        message['content'] = [{'text': message['content']}]
+        messages.append(message)
+
+    # Add user prompt
     messages.append({'role': "user", 'content': [{'text': prompt}]})
 
+    # Initialize Converse API arguments
     args = {
         'modelId': model,
         'messages': messages
     }
 
+    # Define system prompt
     if system_prompt:
         args['system'] = [{'text': system_prompt}]
 
+    # Map and set up inference parameters
+    inference_params_map = {
+        'max_tokens': "maxTokens",
+        'top_p': "topP",
+        'stop_sequences': "stopSequences"
+    }
+    if (inference_params := list(set(kwargs) & set(['max_tokens', 'temperature', 'top_p', 'stop_sequences']))):
+        args['inferenceConfig'] = {}
+        for param in inference_params:
+            args['inferenceConfig'][inference_params_map.get(param, param)] = kwargs.pop(param)
+
+    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
     if hashing_kv is not None:
         args_hash = compute_args_hash(model, messages)
         if_cache_return = await hashing_kv.get_by_id(args_hash)
         if if_cache_return is not None:
             return if_cache_return["return"]
 
+    # Call model via Converse API
     session = aioboto3.Session()
     async with session.client("bedrock-runtime") as bedrock_async_client:
 
-        response = await bedrock_async_client.converse(**args, **kwargs)
+        try:
+            response = await bedrock_async_client.converse(**args, **kwargs)
+        except Exception as e:
+            raise BedrockError(e)
 
         if hashing_kv is not None:
             await hashing_kv.upsert({
@@ -200,7 +230,7 @@ async def bedrock_complete(
     prompt, system_prompt=None, history_messages=[], **kwargs
 ) -> str:
     return await bedrock_complete_if_cache(
-        "anthropic.claude-3-sonnet-20240229-v1:0",
+        "anthropic.claude-3-haiku-20240307-v1:0",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
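
For illustration only, a minimal standalone sketch of the new inference-parameter mapping in bedrock_complete_if_cache; the sample kwargs values below are invented, but the mapping logic mirrors the diff above:

# Hypothetical inputs; in practice these arrive as **kwargs to bedrock_complete_if_cache
kwargs = {'max_tokens': 512, 'temperature': 0.2, 'top_p': 0.9, 'hashing_kv': None}

inference_params_map = {
    'max_tokens': "maxTokens",
    'top_p': "topP",
    'stop_sequences': "stopSequences"
}
args = {}
if (inference_params := list(set(kwargs) & set(['max_tokens', 'temperature', 'top_p', 'stop_sequences']))):
    args['inferenceConfig'] = {}
    for param in inference_params:
        # snake_case names are renamed to the Converse API's camelCase keys;
        # 'temperature' falls through .get(param, param) unchanged
        args['inferenceConfig'][inference_params_map.get(param, param)] = kwargs.pop(param)

print(args['inferenceConfig'])  # e.g. {'maxTokens': 512, 'temperature': 0.2, 'topP': 0.9} (key order may vary)
print(kwargs)                   # {'hashing_kv': None} remains for the separate pop

Because recognized parameters are popped out of kwargs, only unrecognized extras (plus hashing_kv, which is popped separately) are forwarded to the converse() call itself.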