Sanketh Kumar committed on
Commit
5080e1b
·
1 Parent(s): 97a9bc5

Manually reformatted files

Browse files
.github/workflows/linting.yaml CHANGED
@@ -15,7 +15,7 @@ jobs:
15
  steps:
16
  - name: Checkout code
17
  uses: actions/checkout@v2
18
-
19
  - name: Set up Python
20
  uses: actions/setup-python@v2
21
  with:
@@ -27,4 +27,4 @@ jobs:
27
  pip install pre-commit
28
 
29
  - name: Run pre-commit
30
- run: pre-commit run --all-files
 
15
  steps:
16
  - name: Checkout code
17
  uses: actions/checkout@v2
18
+
19
  - name: Set up Python
20
  uses: actions/setup-python@v2
21
  with:
 
27
  pip install pre-commit
28
 
29
  - name: Run pre-commit
30
+ run: pre-commit run --all-files
.gitignore CHANGED
@@ -4,4 +4,4 @@ dickens/
4
  book.txt
5
  lightrag-dev/
6
  .idea/
7
- dist/
 
4
  book.txt
5
  lightrag-dev/
6
  .idea/
7
+ dist/
README.md CHANGED
@@ -58,8 +58,8 @@ from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
58
 
59
  #########
60
  # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
61
- # import nest_asyncio
62
- # nest_asyncio.apply()
63
  #########
64
 
65
  WORKING_DIR = "./dickens"
@@ -157,7 +157,7 @@ rag = LightRAG(
157
 
158
  <details>
159
  <summary> Using Ollama Models </summary>
160
-
161
  * If you want to use Ollama models, you only need to set LightRAG as follows:
162
 
163
  ```python
@@ -328,8 +328,8 @@ def main():
328
  SET e.entity_type = node.entity_type,
329
  e.description = node.description,
330
  e.source_id = node.source_id,
331
- e.displayName = node.id
332
- REMOVE e:Entity
333
  WITH e, node
334
  CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
335
  RETURN count(*)
@@ -382,7 +382,7 @@ def main():
382
 
383
  except Exception as e:
384
  print(f"Error occurred: {e}")
385
-
386
  finally:
387
  driver.close()
388
 
 
58
 
59
  #########
60
  # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
61
+ # import nest_asyncio
62
+ # nest_asyncio.apply()
63
  #########
64
 
65
  WORKING_DIR = "./dickens"
 
157
 
158
  <details>
159
  <summary> Using Ollama Models </summary>
160
+
161
  * If you want to use Ollama models, you only need to set LightRAG as follows:
162
 
163
  ```python
 
328
  SET e.entity_type = node.entity_type,
329
  e.description = node.description,
330
  e.source_id = node.source_id,
331
+ e.displayName = node.id
332
+ REMOVE e:Entity
333
  WITH e, node
334
  CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
335
  RETURN count(*)
 
382
 
383
  except Exception as e:
384
  print(f"Error occurred: {e}")
385
+
386
  finally:
387
  driver.close()
388
 
examples/graph_visual_with_html.py CHANGED
@@ -3,7 +3,7 @@ from pyvis.network import Network
3
  import random
4
 
5
  # Load the GraphML file
6
- G = nx.read_graphml('./dickens/graph_chunk_entity_relation.graphml')
7
 
8
  # Create a Pyvis network
9
  net = Network(notebook=True)
@@ -13,7 +13,7 @@ net.from_nx(G)
13
 
14
  # Add colors to nodes
15
  for node in net.nodes:
16
- node['color'] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
17
 
18
  # Save and display the network
19
- net.show('knowledge_graph.html')
 
3
  import random
4
 
5
  # Load the GraphML file
6
+ G = nx.read_graphml("./dickens/graph_chunk_entity_relation.graphml")
7
 
8
  # Create a Pyvis network
9
  net = Network(notebook=True)
 
13
 
14
  # Add colors to nodes
15
  for node in net.nodes:
16
+ node["color"] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
17
 
18
  # Save and display the network
19
+ net.show("knowledge_graph.html")
examples/graph_visual_with_neo4j.py CHANGED
@@ -13,6 +13,7 @@ NEO4J_URI = "bolt://localhost:7687"
13
  NEO4J_USERNAME = "neo4j"
14
  NEO4J_PASSWORD = "your_password"
15
 
 
16
  def convert_xml_to_json(xml_path, output_path):
17
  """Converts XML file to JSON and saves the output."""
18
  if not os.path.exists(xml_path):
@@ -21,7 +22,7 @@ def convert_xml_to_json(xml_path, output_path):
21
 
22
  json_data = xml_to_json(xml_path)
23
  if json_data:
24
- with open(output_path, 'w', encoding='utf-8') as f:
25
  json.dump(json_data, f, ensure_ascii=False, indent=2)
26
  print(f"JSON file created: {output_path}")
27
  return json_data
@@ -29,16 +30,18 @@ def convert_xml_to_json(xml_path, output_path):
29
  print("Failed to create JSON data")
30
  return None
31
 
 
32
  def process_in_batches(tx, query, data, batch_size):
33
  """Process data in batches and execute the given query."""
34
  for i in range(0, len(data), batch_size):
35
- batch = data[i:i + batch_size]
36
  tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})
37
 
 
38
  def main():
39
  # Paths
40
- xml_file = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml')
41
- json_file = os.path.join(WORKING_DIR, 'graph_data.json')
42
 
43
  # Convert XML to JSON
44
  json_data = convert_xml_to_json(xml_file, json_file)
@@ -46,8 +49,8 @@ def main():
46
  return
47
 
48
  # Load nodes and edges
49
- nodes = json_data.get('nodes', [])
50
- edges = json_data.get('edges', [])
51
 
52
  # Neo4j queries
53
  create_nodes_query = """
@@ -56,8 +59,8 @@ def main():
56
  SET e.entity_type = node.entity_type,
57
  e.description = node.description,
58
  e.source_id = node.source_id,
59
- e.displayName = node.id
60
- REMOVE e:Entity
61
  WITH e, node
62
  CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
63
  RETURN count(*)
@@ -100,19 +103,24 @@ def main():
100
  # Execute queries in batches
101
  with driver.session() as session:
102
  # Insert nodes in batches
103
- session.execute_write(process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES)
 
 
104
 
105
  # Insert edges in batches
106
- session.execute_write(process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES)
 
 
107
 
108
  # Set displayName and labels
109
  session.run(set_displayname_and_labels_query)
110
 
111
  except Exception as e:
112
  print(f"Error occurred: {e}")
113
-
114
  finally:
115
  driver.close()
116
 
 
117
  if __name__ == "__main__":
118
  main()
 
13
  NEO4J_USERNAME = "neo4j"
14
  NEO4J_PASSWORD = "your_password"
15
 
16
+
17
  def convert_xml_to_json(xml_path, output_path):
18
  """Converts XML file to JSON and saves the output."""
19
  if not os.path.exists(xml_path):
 
22
 
23
  json_data = xml_to_json(xml_path)
24
  if json_data:
25
+ with open(output_path, "w", encoding="utf-8") as f:
26
  json.dump(json_data, f, ensure_ascii=False, indent=2)
27
  print(f"JSON file created: {output_path}")
28
  return json_data
 
30
  print("Failed to create JSON data")
31
  return None
32
 
33
+
34
  def process_in_batches(tx, query, data, batch_size):
35
  """Process data in batches and execute the given query."""
36
  for i in range(0, len(data), batch_size):
37
+ batch = data[i : i + batch_size]
38
  tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})
39
 
40
+
41
  def main():
42
  # Paths
43
+ xml_file = os.path.join(WORKING_DIR, "graph_chunk_entity_relation.graphml")
44
+ json_file = os.path.join(WORKING_DIR, "graph_data.json")
45
 
46
  # Convert XML to JSON
47
  json_data = convert_xml_to_json(xml_file, json_file)
 
49
  return
50
 
51
  # Load nodes and edges
52
+ nodes = json_data.get("nodes", [])
53
+ edges = json_data.get("edges", [])
54
 
55
  # Neo4j queries
56
  create_nodes_query = """
 
59
  SET e.entity_type = node.entity_type,
60
  e.description = node.description,
61
  e.source_id = node.source_id,
62
+ e.displayName = node.id
63
+ REMOVE e:Entity
64
  WITH e, node
65
  CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
66
  RETURN count(*)
 
103
  # Execute queries in batches
104
  with driver.session() as session:
105
  # Insert nodes in batches
106
+ session.execute_write(
107
+ process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES
108
+ )
109
 
110
  # Insert edges in batches
111
+ session.execute_write(
112
+ process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES
113
+ )
114
 
115
  # Set displayName and labels
116
  session.run(set_displayname_and_labels_query)
117
 
118
  except Exception as e:
119
  print(f"Error occurred: {e}")
120
+
121
  finally:
122
  driver.close()
123
 
124
+
125
  if __name__ == "__main__":
126
  main()
examples/lightrag_openai_compatible_demo.py CHANGED
@@ -52,6 +52,7 @@ async def test_funcs():
52
 
53
  # asyncio.run(test_funcs())
54
 
 
55
  async def main():
56
  try:
57
  embedding_dimension = await get_embedding_dim()
@@ -61,35 +62,47 @@ async def main():
61
  working_dir=WORKING_DIR,
62
  llm_model_func=llm_model_func,
63
  embedding_func=EmbeddingFunc(
64
- embedding_dim=embedding_dimension, max_token_size=8192, func=embedding_func
 
 
65
  ),
66
  )
67
 
68
-
69
  with open("./book.txt", "r", encoding="utf-8") as f:
70
  rag.insert(f.read())
71
 
72
  # Perform naive search
73
  print(
74
- rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
 
 
75
  )
76
 
77
  # Perform local search
78
  print(
79
- rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
 
 
80
  )
81
 
82
  # Perform global search
83
  print(
84
- rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
 
 
 
85
  )
86
 
87
  # Perform hybrid search
88
  print(
89
- rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
 
 
 
90
  )
91
  except Exception as e:
92
  print(f"An error occurred: {e}")
93
 
 
94
  if __name__ == "__main__":
95
- asyncio.run(main())
 
52
 
53
  # asyncio.run(test_funcs())
54
 
55
+
56
  async def main():
57
  try:
58
  embedding_dimension = await get_embedding_dim()
 
62
  working_dir=WORKING_DIR,
63
  llm_model_func=llm_model_func,
64
  embedding_func=EmbeddingFunc(
65
+ embedding_dim=embedding_dimension,
66
+ max_token_size=8192,
67
+ func=embedding_func,
68
  ),
69
  )
70
 
 
71
  with open("./book.txt", "r", encoding="utf-8") as f:
72
  rag.insert(f.read())
73
 
74
  # Perform naive search
75
  print(
76
+ rag.query(
77
+ "What are the top themes in this story?", param=QueryParam(mode="naive")
78
+ )
79
  )
80
 
81
  # Perform local search
82
  print(
83
+ rag.query(
84
+ "What are the top themes in this story?", param=QueryParam(mode="local")
85
+ )
86
  )
87
 
88
  # Perform global search
89
  print(
90
+ rag.query(
91
+ "What are the top themes in this story?",
92
+ param=QueryParam(mode="global"),
93
+ )
94
  )
95
 
96
  # Perform hybrid search
97
  print(
98
+ rag.query(
99
+ "What are the top themes in this story?",
100
+ param=QueryParam(mode="hybrid"),
101
+ )
102
  )
103
  except Exception as e:
104
  print(f"An error occurred: {e}")
105
 
106
+
107
  if __name__ == "__main__":
108
+ asyncio.run(main())
examples/lightrag_siliconcloud_demo.py CHANGED
@@ -30,7 +30,7 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
30
  texts,
31
  model="netease-youdao/bce-embedding-base_v1",
32
  api_key=os.getenv("SILICONFLOW_API_KEY"),
33
- max_token_size=512
34
  )
35
 
36
 
 
30
  texts,
31
  model="netease-youdao/bce-embedding-base_v1",
32
  api_key=os.getenv("SILICONFLOW_API_KEY"),
33
+ max_token_size=512,
34
  )
35
 
36
 
examples/vram_management_demo.py CHANGED
@@ -27,11 +27,12 @@ rag = LightRAG(
27
  # Read all .txt files from the TEXT_FILES_DIR directory
28
  texts = []
29
  for filename in os.listdir(TEXT_FILES_DIR):
30
- if filename.endswith('.txt'):
31
  file_path = os.path.join(TEXT_FILES_DIR, filename)
32
- with open(file_path, 'r', encoding='utf-8') as file:
33
  texts.append(file.read())
34
 
 
35
  # Batch insert texts into LightRAG with a retry mechanism
36
  def insert_texts_with_retry(rag, texts, retries=3, delay=5):
37
  for _ in range(retries):
@@ -39,37 +40,58 @@ def insert_texts_with_retry(rag, texts, retries=3, delay=5):
39
  rag.insert(texts)
40
  return
41
  except Exception as e:
42
- print(f"Error occurred during insertion: {e}. Retrying in {delay} seconds...")
 
 
43
  time.sleep(delay)
44
  raise RuntimeError("Failed to insert texts after multiple retries.")
45
 
 
46
  insert_texts_with_retry(rag, texts)
47
 
48
  # Perform different types of queries and handle potential errors
49
  try:
50
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
 
 
 
 
51
  except Exception as e:
52
  print(f"Error performing naive search: {e}")
53
 
54
  try:
55
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
 
 
 
 
56
  except Exception as e:
57
  print(f"Error performing local search: {e}")
58
 
59
  try:
60
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
 
 
 
 
61
  except Exception as e:
62
  print(f"Error performing global search: {e}")
63
 
64
  try:
65
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
 
 
 
 
66
  except Exception as e:
67
  print(f"Error performing hybrid search: {e}")
68
 
 
69
  # Function to clear VRAM resources
70
  def clear_vram():
71
  os.system("sudo nvidia-smi --gpu-reset")
72
 
 
73
  # Regularly clear VRAM to prevent overflow
74
  clear_vram_interval = 3600 # Clear once every hour
75
  start_time = time.time()
 
27
  # Read all .txt files from the TEXT_FILES_DIR directory
28
  texts = []
29
  for filename in os.listdir(TEXT_FILES_DIR):
30
+ if filename.endswith(".txt"):
31
  file_path = os.path.join(TEXT_FILES_DIR, filename)
32
+ with open(file_path, "r", encoding="utf-8") as file:
33
  texts.append(file.read())
34
 
35
+
36
  # Batch insert texts into LightRAG with a retry mechanism
37
  def insert_texts_with_retry(rag, texts, retries=3, delay=5):
38
  for _ in range(retries):
 
40
  rag.insert(texts)
41
  return
42
  except Exception as e:
43
+ print(
44
+ f"Error occurred during insertion: {e}. Retrying in {delay} seconds..."
45
+ )
46
  time.sleep(delay)
47
  raise RuntimeError("Failed to insert texts after multiple retries.")
48
 
49
+
50
  insert_texts_with_retry(rag, texts)
51
 
52
  # Perform different types of queries and handle potential errors
53
  try:
54
+ print(
55
+ rag.query(
56
+ "What are the top themes in this story?", param=QueryParam(mode="naive")
57
+ )
58
+ )
59
  except Exception as e:
60
  print(f"Error performing naive search: {e}")
61
 
62
  try:
63
+ print(
64
+ rag.query(
65
+ "What are the top themes in this story?", param=QueryParam(mode="local")
66
+ )
67
+ )
68
  except Exception as e:
69
  print(f"Error performing local search: {e}")
70
 
71
  try:
72
+ print(
73
+ rag.query(
74
+ "What are the top themes in this story?", param=QueryParam(mode="global")
75
+ )
76
+ )
77
  except Exception as e:
78
  print(f"Error performing global search: {e}")
79
 
80
  try:
81
+ print(
82
+ rag.query(
83
+ "What are the top themes in this story?", param=QueryParam(mode="hybrid")
84
+ )
85
+ )
86
  except Exception as e:
87
  print(f"Error performing hybrid search: {e}")
88
 
89
+
90
  # Function to clear VRAM resources
91
  def clear_vram():
92
  os.system("sudo nvidia-smi --gpu-reset")
93
 
94
+
95
  # Regularly clear VRAM to prevent overflow
96
  clear_vram_interval = 3600 # Clear once every hour
97
  start_time = time.time()
lightrag/llm.py CHANGED
@@ -7,7 +7,13 @@ import aiohttp
7
  import numpy as np
8
  import ollama
9
 
10
- from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout, AsyncAzureOpenAI
 
 
 
 
 
 
11
 
12
  import base64
13
  import struct
@@ -70,26 +76,31 @@ async def openai_complete_if_cache(
70
  )
71
  return response.choices[0].message.content
72
 
 
73
  @retry(
74
  stop=stop_after_attempt(3),
75
  wait=wait_exponential(multiplier=1, min=4, max=10),
76
  retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
77
  )
78
- async def azure_openai_complete_if_cache(model,
 
79
  prompt,
80
  system_prompt=None,
81
  history_messages=[],
82
  base_url=None,
83
  api_key=None,
84
- **kwargs):
 
85
  if api_key:
86
  os.environ["AZURE_OPENAI_API_KEY"] = api_key
87
  if base_url:
88
  os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
89
 
90
- openai_async_client = AsyncAzureOpenAI(azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
91
- api_key=os.getenv("AZURE_OPENAI_API_KEY"),
92
- api_version=os.getenv("AZURE_OPENAI_API_VERSION"))
 
 
93
 
94
  hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
95
  messages = []
@@ -114,6 +125,7 @@ async def azure_openai_complete_if_cache(model,
114
  )
115
  return response.choices[0].message.content
116
 
 
117
  class BedrockError(Exception):
118
  """Generic error for issues related to Amazon Bedrock"""
119
 
@@ -205,8 +217,12 @@ async def bedrock_complete_if_cache(
205
 
206
  @lru_cache(maxsize=1)
207
  def initialize_hf_model(model_name):
208
- hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
209
- hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
 
 
 
 
210
  if hf_tokenizer.pad_token is None:
211
  hf_tokenizer.pad_token = hf_tokenizer.eos_token
212
 
@@ -328,8 +344,9 @@ async def gpt_4o_mini_complete(
328
  **kwargs,
329
  )
330
 
 
331
  async def azure_openai_complete(
332
- prompt, system_prompt=None, history_messages=[], **kwargs
333
  ) -> str:
334
  return await azure_openai_complete_if_cache(
335
  "conversation-4o-mini",
@@ -339,6 +356,7 @@ async def azure_openai_complete(
339
  **kwargs,
340
  )
341
 
 
342
  async def bedrock_complete(
343
  prompt, system_prompt=None, history_messages=[], **kwargs
344
  ) -> str:
@@ -418,9 +436,11 @@ async def azure_openai_embedding(
418
  if base_url:
419
  os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
420
 
421
- openai_async_client = AsyncAzureOpenAI(azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
422
- api_key=os.getenv("AZURE_OPENAI_API_KEY"),
423
- api_version=os.getenv("AZURE_OPENAI_API_VERSION"))
 
 
424
 
425
  response = await openai_async_client.embeddings.create(
426
  model=model, input=texts, encoding_format="float"
@@ -440,35 +460,28 @@ async def siliconcloud_embedding(
440
  max_token_size: int = 512,
441
  api_key: str = None,
442
  ) -> np.ndarray:
443
- if api_key and not api_key.startswith('Bearer '):
444
- api_key = 'Bearer ' + api_key
445
 
446
- headers = {
447
- "Authorization": api_key,
448
- "Content-Type": "application/json"
449
- }
450
 
451
  truncate_texts = [text[0:max_token_size] for text in texts]
452
 
453
- payload = {
454
- "model": model,
455
- "input": truncate_texts,
456
- "encoding_format": "base64"
457
- }
458
 
459
  base64_strings = []
460
  async with aiohttp.ClientSession() as session:
461
  async with session.post(base_url, headers=headers, json=payload) as response:
462
  content = await response.json()
463
- if 'code' in content:
464
  raise ValueError(content)
465
- base64_strings = [item['embedding'] for item in content['data']]
466
-
467
  embeddings = []
468
  for string in base64_strings:
469
  decode_bytes = base64.b64decode(string)
470
  n = len(decode_bytes) // 4
471
- float_array = struct.unpack('<' + 'f' * n, decode_bytes)
472
  embeddings.append(float_array)
473
  return np.array(embeddings)
474
 
@@ -563,6 +576,7 @@ async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
563
 
564
  return embed_text
565
 
 
566
  class Model(BaseModel):
567
  """
568
  This is a Pydantic model class named 'Model' that is used to define a custom language model.
@@ -580,14 +594,20 @@ class Model(BaseModel):
580
  The 'kwargs' dictionary contains the model name and API key to be passed to the function.
581
  """
582
 
583
- gen_func: Callable[[Any], str] = Field(..., description="A function that generates the response from the llm. The response must be a string")
584
- kwargs: Dict[str, Any] = Field(..., description="The arguments to pass to the callable function. Eg. the api key, model name, etc")
 
 
 
 
 
 
585
 
586
  class Config:
587
  arbitrary_types_allowed = True
588
 
589
 
590
- class MultiModel():
591
  """
592
  Distributes the load across multiple language models. Useful for circumventing low rate limits with certain api providers especially if you are on the free tier.
593
  Could also be used for splitting across different models or providers.
@@ -611,26 +631,31 @@ class MultiModel():
611
  )
612
  ```
613
  """
 
614
  def __init__(self, models: List[Model]):
615
  self._models = models
616
  self._current_model = 0
617
-
618
  def _next_model(self):
619
  self._current_model = (self._current_model + 1) % len(self._models)
620
  return self._models[self._current_model]
621
 
622
  async def llm_model_func(
623
- self,
624
- prompt, system_prompt=None, history_messages=[], **kwargs
625
  ) -> str:
626
- kwargs.pop("model", None) # stop from overwriting the custom model name
627
  next_model = self._next_model()
628
- args = dict(prompt=prompt, system_prompt=system_prompt, history_messages=history_messages, **kwargs, **next_model.kwargs)
629
-
630
- return await next_model.gen_func(
631
- **args
 
 
632
  )
633
 
 
 
 
634
  if __name__ == "__main__":
635
  import asyncio
636
 
 
7
  import numpy as np
8
  import ollama
9
 
10
+ from openai import (
11
+ AsyncOpenAI,
12
+ APIConnectionError,
13
+ RateLimitError,
14
+ Timeout,
15
+ AsyncAzureOpenAI,
16
+ )
17
 
18
  import base64
19
  import struct
 
76
  )
77
  return response.choices[0].message.content
78
 
79
+
80
  @retry(
81
  stop=stop_after_attempt(3),
82
  wait=wait_exponential(multiplier=1, min=4, max=10),
83
  retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
84
  )
85
+ async def azure_openai_complete_if_cache(
86
+ model,
87
  prompt,
88
  system_prompt=None,
89
  history_messages=[],
90
  base_url=None,
91
  api_key=None,
92
+ **kwargs,
93
+ ):
94
  if api_key:
95
  os.environ["AZURE_OPENAI_API_KEY"] = api_key
96
  if base_url:
97
  os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
98
 
99
+ openai_async_client = AsyncAzureOpenAI(
100
+ azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
101
+ api_key=os.getenv("AZURE_OPENAI_API_KEY"),
102
+ api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
103
+ )
104
 
105
  hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
106
  messages = []
 
125
  )
126
  return response.choices[0].message.content
127
 
128
+
129
  class BedrockError(Exception):
130
  """Generic error for issues related to Amazon Bedrock"""
131
 
 
217
 
218
  @lru_cache(maxsize=1)
219
  def initialize_hf_model(model_name):
220
+ hf_tokenizer = AutoTokenizer.from_pretrained(
221
+ model_name, device_map="auto", trust_remote_code=True
222
+ )
223
+ hf_model = AutoModelForCausalLM.from_pretrained(
224
+ model_name, device_map="auto", trust_remote_code=True
225
+ )
226
  if hf_tokenizer.pad_token is None:
227
  hf_tokenizer.pad_token = hf_tokenizer.eos_token
228
 
 
344
  **kwargs,
345
  )
346
 
347
+
348
  async def azure_openai_complete(
349
+ prompt, system_prompt=None, history_messages=[], **kwargs
350
  ) -> str:
351
  return await azure_openai_complete_if_cache(
352
  "conversation-4o-mini",
 
356
  **kwargs,
357
  )
358
 
359
+
360
  async def bedrock_complete(
361
  prompt, system_prompt=None, history_messages=[], **kwargs
362
  ) -> str:
 
436
  if base_url:
437
  os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
438
 
439
+ openai_async_client = AsyncAzureOpenAI(
440
+ azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
441
+ api_key=os.getenv("AZURE_OPENAI_API_KEY"),
442
+ api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
443
+ )
444
 
445
  response = await openai_async_client.embeddings.create(
446
  model=model, input=texts, encoding_format="float"
 
460
  max_token_size: int = 512,
461
  api_key: str = None,
462
  ) -> np.ndarray:
463
+ if api_key and not api_key.startswith("Bearer "):
464
+ api_key = "Bearer " + api_key
465
 
466
+ headers = {"Authorization": api_key, "Content-Type": "application/json"}
 
 
 
467
 
468
  truncate_texts = [text[0:max_token_size] for text in texts]
469
 
470
+ payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"}
 
 
 
 
471
 
472
  base64_strings = []
473
  async with aiohttp.ClientSession() as session:
474
  async with session.post(base_url, headers=headers, json=payload) as response:
475
  content = await response.json()
476
+ if "code" in content:
477
  raise ValueError(content)
478
+ base64_strings = [item["embedding"] for item in content["data"]]
479
+
480
  embeddings = []
481
  for string in base64_strings:
482
  decode_bytes = base64.b64decode(string)
483
  n = len(decode_bytes) // 4
484
+ float_array = struct.unpack("<" + "f" * n, decode_bytes)
485
  embeddings.append(float_array)
486
  return np.array(embeddings)
487
 
 
576
 
577
  return embed_text
578
 
579
+
580
  class Model(BaseModel):
581
  """
582
  This is a Pydantic model class named 'Model' that is used to define a custom language model.
 
594
  The 'kwargs' dictionary contains the model name and API key to be passed to the function.
595
  """
596
 
597
+ gen_func: Callable[[Any], str] = Field(
598
+ ...,
599
+ description="A function that generates the response from the llm. The response must be a string",
600
+ )
601
+ kwargs: Dict[str, Any] = Field(
602
+ ...,
603
+ description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
604
+ )
605
 
606
  class Config:
607
  arbitrary_types_allowed = True
608
 
609
 
610
+ class MultiModel:
611
  """
612
  Distributes the load across multiple language models. Useful for circumventing low rate limits with certain api providers especially if you are on the free tier.
613
  Could also be used for splitting across different models or providers.
 
631
  )
632
  ```
633
  """
634
+
635
  def __init__(self, models: List[Model]):
636
  self._models = models
637
  self._current_model = 0
638
+
639
  def _next_model(self):
640
  self._current_model = (self._current_model + 1) % len(self._models)
641
  return self._models[self._current_model]
642
 
643
  async def llm_model_func(
644
+ self, prompt, system_prompt=None, history_messages=[], **kwargs
 
645
  ) -> str:
646
+ kwargs.pop("model", None) # stop from overwriting the custom model name
647
  next_model = self._next_model()
648
+ args = dict(
649
+ prompt=prompt,
650
+ system_prompt=system_prompt,
651
+ history_messages=history_messages,
652
+ **kwargs,
653
+ **next_model.kwargs,
654
  )
655
 
656
+ return await next_model.gen_func(**args)
657
+
658
+
659
  if __name__ == "__main__":
660
  import asyncio
661
 
lightrag/utils.py CHANGED
@@ -185,6 +185,7 @@ def save_data_to_file(data, file_name):
185
  with open(file_name, "w", encoding="utf-8") as f:
186
  json.dump(data, f, ensure_ascii=False, indent=4)
187
 
 
188
  def xml_to_json(xml_file):
189
  try:
190
  tree = ET.parse(xml_file)
@@ -194,31 +195,42 @@ def xml_to_json(xml_file):
194
  print(f"Root element: {root.tag}")
195
  print(f"Root attributes: {root.attrib}")
196
 
197
- data = {
198
- "nodes": [],
199
- "edges": []
200
- }
201
 
202
  # Use namespace
203
- namespace = {'': 'http://graphml.graphdrawing.org/xmlns'}
204
 
205
- for node in root.findall('.//node', namespace):
206
  node_data = {
207
- "id": node.get('id').strip('"'),
208
- "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') if node.find("./data[@key='d0']", namespace) is not None else "",
209
- "description": node.find("./data[@key='d1']", namespace).text if node.find("./data[@key='d1']", namespace) is not None else "",
210
- "source_id": node.find("./data[@key='d2']", namespace).text if node.find("./data[@key='d2']", namespace) is not None else ""
 
 
 
 
 
 
211
  }
212
  data["nodes"].append(node_data)
213
 
214
- for edge in root.findall('.//edge', namespace):
215
  edge_data = {
216
- "source": edge.get('source').strip('"'),
217
- "target": edge.get('target').strip('"'),
218
- "weight": float(edge.find("./data[@key='d3']", namespace).text) if edge.find("./data[@key='d3']", namespace) is not None else 0.0,
219
- "description": edge.find("./data[@key='d4']", namespace).text if edge.find("./data[@key='d4']", namespace) is not None else "",
220
- "keywords": edge.find("./data[@key='d5']", namespace).text if edge.find("./data[@key='d5']", namespace) is not None else "",
221
- "source_id": edge.find("./data[@key='d6']", namespace).text if edge.find("./data[@key='d6']", namespace) is not None else ""
 
 
 
 
 
 
 
 
222
  }
223
  data["edges"].append(edge_data)
224
 
 
185
  with open(file_name, "w", encoding="utf-8") as f:
186
  json.dump(data, f, ensure_ascii=False, indent=4)
187
 
188
+
189
  def xml_to_json(xml_file):
190
  try:
191
  tree = ET.parse(xml_file)
 
195
  print(f"Root element: {root.tag}")
196
  print(f"Root attributes: {root.attrib}")
197
 
198
+ data = {"nodes": [], "edges": []}
 
 
 
199
 
200
  # Use namespace
201
+ namespace = {"": "http://graphml.graphdrawing.org/xmlns"}
202
 
203
+ for node in root.findall(".//node", namespace):
204
  node_data = {
205
+ "id": node.get("id").strip('"'),
206
+ "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"')
207
+ if node.find("./data[@key='d0']", namespace) is not None
208
+ else "",
209
+ "description": node.find("./data[@key='d1']", namespace).text
210
+ if node.find("./data[@key='d1']", namespace) is not None
211
+ else "",
212
+ "source_id": node.find("./data[@key='d2']", namespace).text
213
+ if node.find("./data[@key='d2']", namespace) is not None
214
+ else "",
215
  }
216
  data["nodes"].append(node_data)
217
 
218
+ for edge in root.findall(".//edge", namespace):
219
  edge_data = {
220
+ "source": edge.get("source").strip('"'),
221
+ "target": edge.get("target").strip('"'),
222
+ "weight": float(edge.find("./data[@key='d3']", namespace).text)
223
+ if edge.find("./data[@key='d3']", namespace) is not None
224
+ else 0.0,
225
+ "description": edge.find("./data[@key='d4']", namespace).text
226
+ if edge.find("./data[@key='d4']", namespace) is not None
227
+ else "",
228
+ "keywords": edge.find("./data[@key='d5']", namespace).text
229
+ if edge.find("./data[@key='d5']", namespace) is not None
230
+ else "",
231
+ "source_id": edge.find("./data[@key='d6']", namespace).text
232
+ if edge.find("./data[@key='d6']", namespace) is not None
233
+ else "",
234
  }
235
  data["edges"].append(edge_data)
236
 
requirements.txt CHANGED
@@ -1,15 +1,15 @@
1
  accelerate
2
  aioboto3
 
3
  graspologic
4
  hnswlib
5
  nano-vectordb
6
  networkx
7
  ollama
8
  openai
 
9
  tenacity
10
  tiktoken
11
  torch
12
  transformers
13
  xxhash
14
- pyvis
15
- aiohttp
 
1
  accelerate
2
  aioboto3
3
+ aiohttp
4
  graspologic
5
  hnswlib
6
  nano-vectordb
7
  networkx
8
  ollama
9
  openai
10
+ pyvis
11
  tenacity
12
  tiktoken
13
  torch
14
  transformers
15
  xxhash