wiltshirek commited on
Commit
f632fdf
·
2 Parent(s): ad575ba ddcc625

Merge branch 'main' into main

Browse files
.github/workflows/linting.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Linting and Formatting
2
+
3
+ on:
4
+ push:
5
+ branches:
6
+ - main
7
+ pull_request:
8
+ branches:
9
+ - main
10
+
11
+ jobs:
12
+ lint-and-format:
13
+ runs-on: ubuntu-latest
14
+
15
+ steps:
16
+ - name: Checkout code
17
+ uses: actions/checkout@v2
18
+
19
+ - name: Set up Python
20
+ uses: actions/setup-python@v2
21
+ with:
22
+ python-version: '3.x'
23
+
24
+ - name: Install dependencies
25
+ run: |
26
+ python -m pip install --upgrade pip
27
+ pip install pre-commit
28
+
29
+ - name: Run pre-commit
30
+ run: pre-commit run --all-files
.gitignore CHANGED
@@ -8,4 +8,5 @@ dist/
8
  env/
9
  local_neo4jWorkDir/
10
  neo4jWorkDir/
11
- ignore_this.txt
 
 
8
  env/
9
  local_neo4jWorkDir/
10
  neo4jWorkDir/
11
+ ignore_this.txt
12
+ .venv/
README.md CHANGED
@@ -8,7 +8,7 @@
8
  <a href='https://lightrag.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a>
9
  <a href='https://youtu.be/oageL-1I0GE'><img src='https://badges.aleen42.com/src/youtube.svg'></a>
10
  <a href='https://arxiv.org/abs/2410.05779'><img src='https://img.shields.io/badge/arXiv-2410.05779-b31b1b'></a>
11
- <a href='https://discord.gg/mvsfu2Tg'><img src='https://discordapp.com/api/guilds/1296348098003734629/widget.png?style=shield'></a>
12
  </p>
13
  <p>
14
  <img src='https://img.shields.io/github/stars/hkuds/lightrag?color=green&style=social' />
@@ -22,11 +22,17 @@ This repository hosts the code of LightRAG. The structure of this code is based
22
  </div>
23
 
24
  ## 🎉 News
25
- - [x] [2024.10.20]🎯🎯📢📢We’ve added a new feature to LightRAG: Graph Visualization.
26
- - [x] [2024.10.18]🎯🎯📢📢We’ve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
27
- - [x] [2024.10.17]🎯🎯📢📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉
28
- - [x] [2024.10.16]🎯🎯📢📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
29
- - [x] [2024.10.15]🎯🎯📢📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
 
 
 
 
 
 
30
 
31
  ## Install
32
 
@@ -58,8 +64,8 @@ from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
58
 
59
  #########
60
  # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
61
- # import nest_asyncio
62
- # nest_asyncio.apply()
63
  #########
64
 
65
  WORKING_DIR = "./dickens"
@@ -190,8 +196,11 @@ see test_neo4j.py for a working example.
190
 
191
  <details>
192
  <summary> Using Ollama Models </summary>
193
-
194
- * If you want to use Ollama models, you only need to set LightRAG as follows:
 
 
 
195
 
196
  ```python
197
  from lightrag.llm import ollama_model_complete, ollama_embedding
@@ -213,28 +222,59 @@ rag = LightRAG(
213
  )
214
  ```
215
 
216
- * Increasing the `num_ctx` parameter:
 
 
 
217
 
218
  1. Pull the model:
219
- ```python
220
  ollama pull qwen2
221
  ```
222
 
223
  2. Display the model file:
224
- ```python
225
  ollama show --modelfile qwen2 > Modelfile
226
  ```
227
 
228
  3. Edit the Modelfile by adding the following line:
229
- ```python
230
  PARAMETER num_ctx 32768
231
  ```
232
 
233
  4. Create the modified model:
234
- ```python
235
  ollama create -f Modelfile qwen2m
236
  ```
237
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
  </details>
239
 
240
  ### Query Param
@@ -265,12 +305,33 @@ rag.insert(["TEXT1", "TEXT2",...])
265
 
266
  ```python
267
  # Incremental Insert: Insert new documents into an existing LightRAG instance
268
- rag = LightRAG(working_dir="./dickens")
 
 
 
 
 
 
 
 
269
 
270
  with open("./newText.txt") as f:
271
  rag.insert(f.read())
272
  ```
273
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
  ### Graph Visualization
275
 
276
  <details>
@@ -361,8 +422,8 @@ def main():
361
  SET e.entity_type = node.entity_type,
362
  e.description = node.description,
363
  e.source_id = node.source_id,
364
- e.displayName = node.id
365
- REMOVE e:Entity
366
  WITH e, node
367
  CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
368
  RETURN count(*)
@@ -415,7 +476,7 @@ def main():
415
 
416
  except Exception as e:
417
  print(f"Error occurred: {e}")
418
-
419
  finally:
420
  driver.close()
421
 
@@ -425,6 +486,125 @@ if __name__ == "__main__":
425
 
426
  </details>
427
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
428
  ## Evaluation
429
  ### Dataset
430
  The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
@@ -671,12 +851,14 @@ def extract_queries(file_path):
671
  .
672
  ├── examples
673
  │ ├── batch_eval.py
 
674
  │ ├── graph_visual_with_html.py
675
  │ ├── graph_visual_with_neo4j.py
676
- │ ├── generate_query.py
677
  │ ├── lightrag_azure_openai_demo.py
678
  │ ├── lightrag_bedrock_demo.py
679
  │ ├── lightrag_hf_demo.py
 
680
  │ ├── lightrag_ollama_demo.py
681
  │ ├── lightrag_openai_compatible_demo.py
682
  │ ├── lightrag_openai_demo.py
@@ -693,8 +875,10 @@ def extract_queries(file_path):
693
  │ └── utils.py
694
  ├── reproduce
695
  │ ├── Step_0.py
 
696
  │ ├── Step_1.py
697
  │ ├── Step_2.py
 
698
  │ └── Step_3.py
699
  ├── .gitignore
700
  ├── .pre-commit-config.yaml
@@ -726,3 +910,6 @@ archivePrefix={arXiv},
726
  primaryClass={cs.IR}
727
  }
728
  ```
 
 
 
 
8
  <a href='https://lightrag.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a>
9
  <a href='https://youtu.be/oageL-1I0GE'><img src='https://badges.aleen42.com/src/youtube.svg'></a>
10
  <a href='https://arxiv.org/abs/2410.05779'><img src='https://img.shields.io/badge/arXiv-2410.05779-b31b1b'></a>
11
+ <a href='https://discord.gg/rdE8YVPm'><img src='https://discordapp.com/api/guilds/1296348098003734629/widget.png?style=shield'></a>
12
  </p>
13
  <p>
14
  <img src='https://img.shields.io/github/stars/hkuds/lightrag?color=green&style=social' />
 
22
  </div>
23
 
24
  ## 🎉 News
25
+ - [x] [2024.10.29]🎯📢LightRAG now supports multiple file types, including PDF, DOC, PPT, and CSV via `textract`.
26
+ - [x] [2024.10.20]🎯📢We’ve added a new feature to LightRAG: Graph Visualization.
27
+ - [x] [2024.10.18]🎯📢We’ve added a link to a [LightRAG Introduction Video](https://youtu.be/oageL-1I0GE). Thanks to the author!
28
+ - [x] [2024.10.17]🎯📢We have created a [Discord channel](https://discord.gg/mvsfu2Tg)! Welcome to join for sharing and discussions! 🎉🎉
29
+ - [x] [2024.10.16]🎯📢LightRAG now supports [Ollama models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
30
+ - [x] [2024.10.15]🎯📢LightRAG now supports [Hugging Face models](https://github.com/HKUDS/LightRAG?tab=readme-ov-file#quick-start)!
31
+
32
+ ## Algorithm Flowchart
33
+
34
+ ![LightRAG_Self excalidraw](https://github.com/user-attachments/assets/aa5c4892-2e44-49e6-a116-2403ed80a1a3)
35
+
36
 
37
  ## Install
38
 
 
64
 
65
  #########
66
  # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
67
+ # import nest_asyncio
68
+ # nest_asyncio.apply()
69
  #########
70
 
71
  WORKING_DIR = "./dickens"
 
196
 
197
  <details>
198
  <summary> Using Ollama Models </summary>
199
+
200
+ ### Overview
201
+ If you want to use Ollama models, you need to pull the model you plan to use as well as an embedding model, for example `nomic-embed-text`.
202
+
203
+ Then you only need to set LightRAG as follows:
204
 
205
  ```python
206
  from lightrag.llm import ollama_model_complete, ollama_embedding
 
222
  )
223
  ```
224
 
225
+ ### Increasing context size
226
+ For LightRAG to work, the context size should be at least 32k tokens. By default, Ollama models have a context size of 8k. You can increase it in one of two ways:
227
+
228
+ #### Increasing the `num_ctx` parameter in Modelfile.
229
 
230
  1. Pull the model:
231
+ ```bash
232
  ollama pull qwen2
233
  ```
234
 
235
  2. Display the model file:
236
+ ```bash
237
  ollama show --modelfile qwen2 > Modelfile
238
  ```
239
 
240
  3. Edit the Modelfile by adding the following line:
241
+ ```bash
242
  PARAMETER num_ctx 32768
243
  ```
244
 
245
  4. Create the modified model:
246
+ ```bash
247
  ollama create -f Modelfile qwen2m
248
  ```
249
 
250
+ #### Setup `num_ctx` via Ollama API.
251
+ You can use the `llm_model_kwargs` parameter to configure Ollama:
252
+
253
+ ```python
254
+ rag = LightRAG(
255
+ working_dir=WORKING_DIR,
256
+ llm_model_func=ollama_model_complete, # Use Ollama model for text generation
257
+ llm_model_name='your_model_name', # Your model name
258
+ llm_model_kwargs={"options": {"num_ctx": 32768}},
259
+ # Use Ollama embedding function
260
+ embedding_func=EmbeddingFunc(
261
+ embedding_dim=768,
262
+ max_token_size=8192,
263
+ func=lambda texts: ollama_embedding(
264
+ texts,
265
+ embed_model="nomic-embed-text"
266
+ )
267
+ ),
268
+ )
269
+ ```
270
+ #### Fully functional example
271
+
272
+ There is a fully functional example, `examples/lightrag_ollama_demo.py`, that uses the `gemma2:2b` model, runs only 4 requests in parallel, and sets the context size to 32k.
273
+
274
+ #### Low RAM GPUs
275
+
276
+ To run this experiment on a low-RAM GPU, select a small model and tune the context window (increasing the context size increases memory consumption). For example, running this Ollama example on a repurposed mining GPU with 6 GB of RAM required setting the context size to 26k while using `gemma2:2b`. It was able to find 197 entities and 19 relations in `book.txt`.
277
+
278
  </details>
279
 
280
  ### Query Param
 
305
 
306
  ```python
307
  # Incremental Insert: Insert new documents into an existing LightRAG instance
308
+ rag = LightRAG(
309
+ working_dir=WORKING_DIR,
310
+ llm_model_func=llm_model_func,
311
+ embedding_func=EmbeddingFunc(
312
+ embedding_dim=embedding_dimension,
313
+ max_token_size=8192,
314
+ func=embedding_func,
315
+ ),
316
+ )
317
 
318
  with open("./newText.txt") as f:
319
  rag.insert(f.read())
320
  ```
321
 
322
+ ### Multi-file Type Support
323
+
324
+ The `textract` package supports reading file types such as TXT, DOCX, PPTX, CSV, and PDF.
325
+
326
+ ```python
327
+ import textract
328
+
329
+ file_path = 'TEXT.pdf'
330
+ text_content = textract.process(file_path)
331
+
332
+ rag.insert(text_content.decode('utf-8'))
333
+ ```
334
+
335
  ### Graph Visualization
336
 
337
  <details>
 
422
  SET e.entity_type = node.entity_type,
423
  e.description = node.description,
424
  e.source_id = node.source_id,
425
+ e.displayName = node.id
426
+ REMOVE e:Entity
427
  WITH e, node
428
  CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
429
  RETURN count(*)
 
476
 
477
  except Exception as e:
478
  print(f"Error occurred: {e}")
479
+
480
  finally:
481
  driver.close()
482
 
 
486
 
487
  </details>
488
 
489
+ ## API Server Implementation
490
+
491
+ LightRAG also provides a FastAPI-based server implementation for RESTful API access to RAG operations. This allows you to run LightRAG as a service and interact with it through HTTP requests.
492
+
493
+ ### Setting up the API Server
494
+ <details>
495
+ <summary>Click to expand setup instructions</summary>
496
+
497
+ 1. First, ensure you have the required dependencies:
498
+ ```bash
499
+ pip install fastapi uvicorn pydantic
500
+ ```
501
+
502
+ 2. Set up your environment variables:
503
+ ```bash
504
+ export RAG_DIR="your_index_directory" # Optional: Defaults to "index_default"
505
+ ```
506
+
507
+ 3. Run the API server:
508
+ ```bash
509
+ python examples/lightrag_api_openai_compatible_demo.py
510
+ ```
511
+
512
+ The server will start on `http://0.0.0.0:8020`.
513
+ </details>
514
+
515
+ ### API Endpoints
516
+
517
+ The API server provides the following endpoints:
518
+
519
+ #### 1. Query Endpoint
520
+ <details>
521
+ <summary>Click to view Query endpoint details</summary>
522
+
523
+ - **URL:** `/query`
524
+ - **Method:** POST
525
+ - **Body:**
526
+ ```json
527
+ {
528
+ "query": "Your question here",
529
+ "mode": "hybrid" // Can be "naive", "local", "global", or "hybrid"
530
+ }
531
+ ```
532
+ - **Example:**
533
+ ```bash
534
+ curl -X POST "http://127.0.0.1:8020/query" \
535
+ -H "Content-Type: application/json" \
536
+ -d '{"query": "What are the main themes?", "mode": "hybrid"}'
537
+ ```
538
+ </details>
539
+
540
+ #### 2. Insert Text Endpoint
541
+ <details>
542
+ <summary>Click to view Insert Text endpoint details</summary>
543
+
544
+ - **URL:** `/insert`
545
+ - **Method:** POST
546
+ - **Body:**
547
+ ```json
548
+ {
549
+ "text": "Your text content here"
550
+ }
551
+ ```
552
+ - **Example:**
553
+ ```bash
554
+ curl -X POST "http://127.0.0.1:8020/insert" \
555
+ -H "Content-Type: application/json" \
556
+ -d '{"text": "Content to be inserted into RAG"}'
557
+ ```
558
+ </details>
559
+
560
+ #### 3. Insert File Endpoint
561
+ <details>
562
+ <summary>Click to view Insert File endpoint details</summary>
563
+
564
+ - **URL:** `/insert_file`
565
+ - **Method:** POST
566
+ - **Body:**
567
+ ```json
568
+ {
569
+ "file_path": "path/to/your/file.txt"
570
+ }
571
+ ```
572
+ - **Example:**
573
+ ```bash
574
+ curl -X POST "http://127.0.0.1:8020/insert_file" \
575
+ -H "Content-Type: application/json" \
576
+ -d '{"file_path": "./book.txt"}'
577
+ ```
578
+ </details>
579
+
580
+ #### 4. Health Check Endpoint
581
+ <details>
582
+ <summary>Click to view Health Check endpoint details</summary>
583
+
584
+ - **URL:** `/health`
585
+ - **Method:** GET
586
+ - **Example:**
587
+ ```bash
588
+ curl -X GET "http://127.0.0.1:8020/health"
589
+ ```
590
+ </details>
591
+
592
+ ### Configuration
593
+
594
+ The API server can be configured using environment variables:
595
+ - `RAG_DIR`: Directory for storing the RAG index (default: "index_default")
596
+ - API keys and base URLs should be configured in the code for your specific LLM and embedding model providers
597
+
598
+ ### Error Handling
599
+ <details>
600
+ <summary>Click to view error handling details</summary>
601
+
602
+ The API includes comprehensive error handling:
603
+ - File not found errors (404)
604
+ - Processing errors (500)
605
+ - Supports multiple file encodings (UTF-8 and GBK)
606
+ </details>
607
+
608
  ## Evaluation
609
  ### Dataset
610
  The dataset used in LightRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
 
851
  .
852
  ├── examples
853
  │ ├── batch_eval.py
854
+ │ ├── generate_query.py
855
  │ ├── graph_visual_with_html.py
856
  │ ├── graph_visual_with_neo4j.py
857
+ │ ├── lightrag_api_openai_compatible_demo.py
858
  │ ├── lightrag_azure_openai_demo.py
859
  │ ├── lightrag_bedrock_demo.py
860
  │ ├── lightrag_hf_demo.py
861
+ │ ├── lightrag_lmdeploy_demo.py
862
  │ ├── lightrag_ollama_demo.py
863
  │ ├── lightrag_openai_compatible_demo.py
864
  │ ├── lightrag_openai_demo.py
 
875
  │ └── utils.py
876
  ├── reproduce
877
  │ ├── Step_0.py
878
+ │ ├── Step_1_openai_compatible.py
879
  │ ├── Step_1.py
880
  │ ├── Step_2.py
881
+ │ ├── Step_3_openai_compatible.py
882
  │ └── Step_3.py
883
  ├── .gitignore
884
  ├── .pre-commit-config.yaml
 
910
  primaryClass={cs.IR}
911
  }
912
  ```
913
+
914
+
915
+
examples/graph_visual_with_html.py CHANGED
@@ -3,17 +3,17 @@ from pyvis.network import Network
3
  import random
4
 
5
  # Load the GraphML file
6
- G = nx.read_graphml('./dickens/graph_chunk_entity_relation.graphml')
7
 
8
  # Create a Pyvis network
9
- net = Network(notebook=True)
10
 
11
  # Convert NetworkX graph to Pyvis network
12
  net.from_nx(G)
13
 
14
  # Add colors to nodes
15
  for node in net.nodes:
16
- node['color'] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
17
 
18
  # Save and display the network
19
- net.show('knowledge_graph.html')
 
3
  import random
4
 
5
  # Load the GraphML file
6
+ G = nx.read_graphml("./dickens/graph_chunk_entity_relation.graphml")
7
 
8
  # Create a Pyvis network
9
+ net = Network(height="100vh", notebook=True)
10
 
11
  # Convert NetworkX graph to Pyvis network
12
  net.from_nx(G)
13
 
14
  # Add colors to nodes
15
  for node in net.nodes:
16
+ node["color"] = "#{:06x}".format(random.randint(0, 0xFFFFFF))
17
 
18
  # Save and display the network
19
+ net.show("knowledge_graph.html")
examples/graph_visual_with_neo4j.py CHANGED
@@ -13,6 +13,7 @@ NEO4J_URI = "bolt://localhost:7687"
13
  NEO4J_USERNAME = "neo4j"
14
  NEO4J_PASSWORD = "your_password"
15
 
 
16
  def convert_xml_to_json(xml_path, output_path):
17
  """Converts XML file to JSON and saves the output."""
18
  if not os.path.exists(xml_path):
@@ -21,7 +22,7 @@ def convert_xml_to_json(xml_path, output_path):
21
 
22
  json_data = xml_to_json(xml_path)
23
  if json_data:
24
- with open(output_path, 'w', encoding='utf-8') as f:
25
  json.dump(json_data, f, ensure_ascii=False, indent=2)
26
  print(f"JSON file created: {output_path}")
27
  return json_data
@@ -29,16 +30,18 @@ def convert_xml_to_json(xml_path, output_path):
29
  print("Failed to create JSON data")
30
  return None
31
 
 
32
  def process_in_batches(tx, query, data, batch_size):
33
  """Process data in batches and execute the given query."""
34
  for i in range(0, len(data), batch_size):
35
- batch = data[i:i + batch_size]
36
  tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})
37
 
 
38
  def main():
39
  # Paths
40
- xml_file = os.path.join(WORKING_DIR, 'graph_chunk_entity_relation.graphml')
41
- json_file = os.path.join(WORKING_DIR, 'graph_data.json')
42
 
43
  # Convert XML to JSON
44
  json_data = convert_xml_to_json(xml_file, json_file)
@@ -46,8 +49,8 @@ def main():
46
  return
47
 
48
  # Load nodes and edges
49
- nodes = json_data.get('nodes', [])
50
- edges = json_data.get('edges', [])
51
 
52
  # Neo4j queries
53
  create_nodes_query = """
@@ -56,8 +59,8 @@ def main():
56
  SET e.entity_type = node.entity_type,
57
  e.description = node.description,
58
  e.source_id = node.source_id,
59
- e.displayName = node.id
60
- REMOVE e:Entity
61
  WITH e, node
62
  CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
63
  RETURN count(*)
@@ -100,19 +103,24 @@ def main():
100
  # Execute queries in batches
101
  with driver.session() as session:
102
  # Insert nodes in batches
103
- session.execute_write(process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES)
 
 
104
 
105
  # Insert edges in batches
106
- session.execute_write(process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES)
 
 
107
 
108
  # Set displayName and labels
109
  session.run(set_displayname_and_labels_query)
110
 
111
  except Exception as e:
112
  print(f"Error occurred: {e}")
113
-
114
  finally:
115
  driver.close()
116
 
 
117
  if __name__ == "__main__":
118
  main()
 
13
  NEO4J_USERNAME = "neo4j"
14
  NEO4J_PASSWORD = "your_password"
15
 
16
+
17
  def convert_xml_to_json(xml_path, output_path):
18
  """Converts XML file to JSON and saves the output."""
19
  if not os.path.exists(xml_path):
 
22
 
23
  json_data = xml_to_json(xml_path)
24
  if json_data:
25
+ with open(output_path, "w", encoding="utf-8") as f:
26
  json.dump(json_data, f, ensure_ascii=False, indent=2)
27
  print(f"JSON file created: {output_path}")
28
  return json_data
 
30
  print("Failed to create JSON data")
31
  return None
32
 
33
+
34
  def process_in_batches(tx, query, data, batch_size):
35
  """Process data in batches and execute the given query."""
36
  for i in range(0, len(data), batch_size):
37
+ batch = data[i : i + batch_size]
38
  tx.run(query, {"nodes": batch} if "nodes" in query else {"edges": batch})
39
 
40
+
41
  def main():
42
  # Paths
43
+ xml_file = os.path.join(WORKING_DIR, "graph_chunk_entity_relation.graphml")
44
+ json_file = os.path.join(WORKING_DIR, "graph_data.json")
45
 
46
  # Convert XML to JSON
47
  json_data = convert_xml_to_json(xml_file, json_file)
 
49
  return
50
 
51
  # Load nodes and edges
52
+ nodes = json_data.get("nodes", [])
53
+ edges = json_data.get("edges", [])
54
 
55
  # Neo4j queries
56
  create_nodes_query = """
 
59
  SET e.entity_type = node.entity_type,
60
  e.description = node.description,
61
  e.source_id = node.source_id,
62
+ e.displayName = node.id
63
+ REMOVE e:Entity
64
  WITH e, node
65
  CALL apoc.create.addLabels(e, [node.entity_type]) YIELD node AS labeledNode
66
  RETURN count(*)
 
103
  # Execute queries in batches
104
  with driver.session() as session:
105
  # Insert nodes in batches
106
+ session.execute_write(
107
+ process_in_batches, create_nodes_query, nodes, BATCH_SIZE_NODES
108
+ )
109
 
110
  # Insert edges in batches
111
+ session.execute_write(
112
+ process_in_batches, create_edges_query, edges, BATCH_SIZE_EDGES
113
+ )
114
 
115
  # Set displayName and labels
116
  session.run(set_displayname_and_labels_query)
117
 
118
  except Exception as e:
119
  print(f"Error occurred: {e}")
120
+
121
  finally:
122
  driver.close()
123
 
124
+
125
  if __name__ == "__main__":
126
  main()
examples/lightrag_api_openai_compatible_demo.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ import os
4
+ from lightrag import LightRAG, QueryParam
5
+ from lightrag.llm import openai_complete_if_cache, openai_embedding
6
+ from lightrag.utils import EmbeddingFunc
7
+ import numpy as np
8
+ from typing import Optional
9
+ import asyncio
10
+ import nest_asyncio
11
+
12
+ # Apply nest_asyncio to solve event loop issues
13
+ nest_asyncio.apply()
14
+
15
+ DEFAULT_RAG_DIR = "index_default"
16
+ app = FastAPI(title="LightRAG API", description="API for RAG operations")
17
+
18
+ # Configure working directory
19
+ WORKING_DIR = os.environ.get("RAG_DIR", f"{DEFAULT_RAG_DIR}")
20
+ print(f"WORKING_DIR: {WORKING_DIR}")
21
+ if not os.path.exists(WORKING_DIR):
22
+ os.mkdir(WORKING_DIR)
23
+
24
+ # LLM model function
25
+
26
+
27
+ async def llm_model_func(
28
+ prompt, system_prompt=None, history_messages=[], **kwargs
29
+ ) -> str:
30
+ return await openai_complete_if_cache(
31
+ "gpt-4o-mini",
32
+ prompt,
33
+ system_prompt=system_prompt,
34
+ history_messages=history_messages,
35
+ api_key="YOUR_API_KEY",
36
+ base_url="YourURL/v1",
37
+ **kwargs,
38
+ )
39
+
40
+
41
+ # Embedding function
42
+
43
+
44
+ async def embedding_func(texts: list[str]) -> np.ndarray:
45
+ return await openai_embedding(
46
+ texts,
47
+ model="text-embedding-3-large",
48
+ api_key="YOUR_API_KEY",
49
+ base_url="YourURL/v1",
50
+ )
51
+
52
+
53
+ # Initialize RAG instance
54
+ rag = LightRAG(
55
+ working_dir=WORKING_DIR,
56
+ llm_model_func=llm_model_func,
57
+ embedding_func=EmbeddingFunc(
58
+ embedding_dim=3072, max_token_size=8192, func=embedding_func
59
+ ),
60
+ )
61
+
62
+ # Data models
63
+
64
+
65
+ class QueryRequest(BaseModel):
66
+ query: str
67
+ mode: str = "hybrid"
68
+
69
+
70
+ class InsertRequest(BaseModel):
71
+ text: str
72
+
73
+
74
+ class InsertFileRequest(BaseModel):
75
+ file_path: str
76
+
77
+
78
+ class Response(BaseModel):
79
+ status: str
80
+ data: Optional[str] = None
81
+ message: Optional[str] = None
82
+
83
+
84
+ # API routes
85
+
86
+
87
+ @app.post("/query", response_model=Response)
88
+ async def query_endpoint(request: QueryRequest):
89
+ try:
90
+ loop = asyncio.get_event_loop()
91
+ result = await loop.run_in_executor(
92
+ None, lambda: rag.query(request.query, param=QueryParam(mode=request.mode))
93
+ )
94
+ return Response(status="success", data=result)
95
+ except Exception as e:
96
+ raise HTTPException(status_code=500, detail=str(e))
97
+
98
+
99
+ @app.post("/insert", response_model=Response)
100
+ async def insert_endpoint(request: InsertRequest):
101
+ try:
102
+ loop = asyncio.get_event_loop()
103
+ await loop.run_in_executor(None, lambda: rag.insert(request.text))
104
+ return Response(status="success", message="Text inserted successfully")
105
+ except Exception as e:
106
+ raise HTTPException(status_code=500, detail=str(e))
107
+
108
+
109
+ @app.post("/insert_file", response_model=Response)
110
+ async def insert_file(request: InsertFileRequest):
111
+ try:
112
+ # Check if file exists
113
+ if not os.path.exists(request.file_path):
114
+ raise HTTPException(
115
+ status_code=404, detail=f"File not found: {request.file_path}"
116
+ )
117
+
118
+ # Read file content
119
+ try:
120
+ with open(request.file_path, "r", encoding="utf-8") as f:
121
+ content = f.read()
122
+ except UnicodeDecodeError:
123
+ # If UTF-8 decoding fails, try other encodings
124
+ with open(request.file_path, "r", encoding="gbk") as f:
125
+ content = f.read()
126
+
127
+ # Insert file content
128
+ loop = asyncio.get_event_loop()
129
+ await loop.run_in_executor(None, lambda: rag.insert(content))
130
+
131
+ return Response(
132
+ status="success",
133
+ message=f"File content from {request.file_path} inserted successfully",
134
+ )
135
+ except Exception as e:
136
+ raise HTTPException(status_code=500, detail=str(e))
137
+
138
+
139
+ @app.get("/health")
140
+ async def health_check():
141
+ return {"status": "healthy"}
142
+
143
+
144
+ if __name__ == "__main__":
145
+ import uvicorn
146
+
147
+ uvicorn.run(app, host="0.0.0.0", port=8020)
148
+
149
+ # Usage example
150
+ # To run the server, use the following command in your terminal:
151
+ # python lightrag_api_openai_compatible_demo.py
152
+
153
+ # Example requests:
154
+ # 1. Query:
155
+ # curl -X POST "http://127.0.0.1:8020/query" -H "Content-Type: application/json" -d '{"query": "your query here", "mode": "hybrid"}'
156
+
157
+ # 2. Insert text:
158
+ # curl -X POST "http://127.0.0.1:8020/insert" -H "Content-Type: application/json" -d '{"text": "your text here"}'
159
+
160
+ # 3. Insert file:
161
+ # curl -X POST "http://127.0.0.1:8020/insert_file" -H "Content-Type: application/json" -d '{"file_path": "path/to/your/file.txt"}'
162
+
163
+ # 4. Health check:
164
+ # curl -X GET "http://127.0.0.1:8020/health"
examples/lightrag_lmdeploy_demo.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from lightrag import LightRAG, QueryParam
4
+ from lightrag.llm import lmdeploy_model_if_cache, hf_embedding
5
+ from lightrag.utils import EmbeddingFunc
6
+ from transformers import AutoModel, AutoTokenizer
7
+
8
+ WORKING_DIR = "./dickens"
9
+
10
+ if not os.path.exists(WORKING_DIR):
11
+ os.mkdir(WORKING_DIR)
12
+
13
+
14
+ async def lmdeploy_model_complete(
15
+ prompt=None, system_prompt=None, history_messages=[], **kwargs
16
+ ) -> str:
17
+ model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
18
+ return await lmdeploy_model_if_cache(
19
+ model_name,
20
+ prompt,
21
+ system_prompt=system_prompt,
22
+ history_messages=history_messages,
23
+ ## please specify chat_template if your local path does not follow original HF file name,
24
+ ## or model_name is a pytorch model on huggingface.co,
25
+ ## you can refer to https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/model.py
26
+ ## for a list of chat_template available in lmdeploy.
27
+ chat_template="llama3",
28
+ # model_format ='awq', # if you are using awq quantization model.
29
+ # quant_policy=8, # if you want to use online kv cache, 4=kv int4, 8=kv int8.
30
+ **kwargs,
31
+ )
32
+
33
+
34
+ rag = LightRAG(
35
+ working_dir=WORKING_DIR,
36
+ llm_model_func=lmdeploy_model_complete,
37
+ llm_model_name="meta-llama/Llama-3.1-8B-Instruct", # please use definite path for local model
38
+ embedding_func=EmbeddingFunc(
39
+ embedding_dim=384,
40
+ max_token_size=5000,
41
+ func=lambda texts: hf_embedding(
42
+ texts,
43
+ tokenizer=AutoTokenizer.from_pretrained(
44
+ "sentence-transformers/all-MiniLM-L6-v2"
45
+ ),
46
+ embed_model=AutoModel.from_pretrained(
47
+ "sentence-transformers/all-MiniLM-L6-v2"
48
+ ),
49
+ ),
50
+ ),
51
+ )
52
+
53
+
54
+ with open("./book.txt", "r", encoding="utf-8") as f:
55
+ rag.insert(f.read())
56
+
57
+ # Perform naive search
58
+ print(
59
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
60
+ )
61
+
62
+ # Perform local search
63
+ print(
64
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
65
+ )
66
+
67
+ # Perform global search
68
+ print(
69
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
70
+ )
71
+
72
+ # Perform hybrid search
73
+ print(
74
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
75
+ )
examples/lightrag_ollama_demo.py CHANGED
@@ -1,26 +1,32 @@
1
  import os
2
-
3
  from lightrag import LightRAG, QueryParam
4
  from lightrag.llm import ollama_model_complete, ollama_embedding
5
  from lightrag.utils import EmbeddingFunc
6
 
7
  WORKING_DIR = "./dickens"
8
 
 
 
9
  if not os.path.exists(WORKING_DIR):
10
  os.mkdir(WORKING_DIR)
11
 
12
  rag = LightRAG(
13
  working_dir=WORKING_DIR,
14
  llm_model_func=ollama_model_complete,
15
- llm_model_name="your_model_name",
 
 
 
16
  embedding_func=EmbeddingFunc(
17
  embedding_dim=768,
18
  max_token_size=8192,
19
- func=lambda texts: ollama_embedding(texts, embed_model="nomic-embed-text"),
 
 
20
  ),
21
  )
22
 
23
-
24
  with open("./book.txt", "r", encoding="utf-8") as f:
25
  rag.insert(f.read())
26
 
 
1
  import os
2
+ import logging
3
  from lightrag import LightRAG, QueryParam
4
  from lightrag.llm import ollama_model_complete, ollama_embedding
5
  from lightrag.utils import EmbeddingFunc
6
 
7
  WORKING_DIR = "./dickens"
8
 
9
+ logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
10
+
11
  if not os.path.exists(WORKING_DIR):
12
  os.mkdir(WORKING_DIR)
13
 
14
  rag = LightRAG(
15
  working_dir=WORKING_DIR,
16
  llm_model_func=ollama_model_complete,
17
+ llm_model_name="gemma2:2b",
18
+ llm_model_max_async=4,
19
+ llm_model_max_token_size=32768,
20
+ llm_model_kwargs={"host": "http://localhost:11434", "options": {"num_ctx": 32768}},
21
  embedding_func=EmbeddingFunc(
22
  embedding_dim=768,
23
  max_token_size=8192,
24
+ func=lambda texts: ollama_embedding(
25
+ texts, embed_model="nomic-embed-text", host="http://localhost:11434"
26
+ ),
27
  ),
28
  )
29
 
 
30
  with open("./book.txt", "r", encoding="utf-8") as f:
31
  rag.insert(f.read())
32
 
examples/lightrag_openai_compatible_demo.py CHANGED
@@ -34,6 +34,13 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
34
  )
35
 
36
 
 
 
 
 
 
 
 
37
  # function test
38
  async def test_funcs():
39
  result = await llm_model_func("How are you?")
@@ -43,37 +50,59 @@ async def test_funcs():
43
  print("embedding_func: ", result)
44
 
45
 
46
- asyncio.run(test_funcs())
47
-
48
-
49
- rag = LightRAG(
50
- working_dir=WORKING_DIR,
51
- llm_model_func=llm_model_func,
52
- embedding_func=EmbeddingFunc(
53
- embedding_dim=4096, max_token_size=8192, func=embedding_func
54
- ),
55
- )
56
-
57
-
58
- with open("./book.txt", "r", encoding="utf-8") as f:
59
- rag.insert(f.read())
60
-
61
- # Perform naive search
62
- print(
63
- rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
64
- )
65
-
66
- # Perform local search
67
- print(
68
- rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
69
- )
70
-
71
- # Perform global search
72
- print(
73
- rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
74
- )
75
-
76
- # Perform hybrid search
77
- print(
78
- rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
79
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  )
35
 
36
 
37
async def get_embedding_dim():
    """Embed one probe sentence and return the second dimension of the result.

    The value is read from ``.shape[1]`` of whatever ``embedding_func``
    returns for a single-element batch.
    """
    probe_batch = ["This is a test sentence."]
    vectors = await embedding_func(probe_batch)
    return vectors.shape[1]
42
+
43
+
44
  # function test
45
  async def test_funcs():
46
  result = await llm_model_func("How are you?")
 
50
  print("embedding_func: ", result)
51
 
52
 
53
+ # asyncio.run(test_funcs())
54
+
55
+
56
async def main():
    """Run the demo end to end: probe the embedding size, index the book, query.

    The same question is asked once per retrieval mode (naive, local, global,
    hybrid, in that order) and each answer is printed. Any exception raised
    along the way is caught and reported on stdout instead of propagating.
    """
    try:
        embedding_dimension = await get_embedding_dim()
        print(f"Detected embedding dimension: {embedding_dimension}")

        embedder = EmbeddingFunc(
            embedding_dim=embedding_dimension,
            max_token_size=8192,
            func=embedding_func,
        )
        rag = LightRAG(
            working_dir=WORKING_DIR,
            llm_model_func=llm_model_func,
            embedding_func=embedder,
        )

        with open("./book.txt", "r", encoding="utf-8") as f:
            await rag.ainsert(f.read())

        # Perform naive, local, global and hybrid search, one after another.
        question = "What are the top themes in this story?"
        for mode in ("naive", "local", "global", "hybrid"):
            answer = await rag.aquery(question, param=QueryParam(mode=mode))
            print(answer)
    except Exception as e:
        print(f"An error occurred: {e}")
105
+
106
+
107
+ if __name__ == "__main__":
108
+ asyncio.run(main())
examples/lightrag_siliconcloud_demo.py CHANGED
@@ -30,7 +30,7 @@ async def embedding_func(texts: list[str]) -> np.ndarray:
30
  texts,
31
  model="netease-youdao/bce-embedding-base_v1",
32
  api_key=os.getenv("SILICONFLOW_API_KEY"),
33
- max_token_size=512
34
  )
35
 
36
 
 
30
  texts,
31
  model="netease-youdao/bce-embedding-base_v1",
32
  api_key=os.getenv("SILICONFLOW_API_KEY"),
33
+ max_token_size=512,
34
  )
35
 
36
 
examples/vram_management_demo.py CHANGED
@@ -27,11 +27,12 @@ rag = LightRAG(
27
  # Read all .txt files from the TEXT_FILES_DIR directory
28
  texts = []
29
  for filename in os.listdir(TEXT_FILES_DIR):
30
- if filename.endswith('.txt'):
31
  file_path = os.path.join(TEXT_FILES_DIR, filename)
32
- with open(file_path, 'r', encoding='utf-8') as file:
33
  texts.append(file.read())
34
 
 
35
  # Batch insert texts into LightRAG with a retry mechanism
36
  def insert_texts_with_retry(rag, texts, retries=3, delay=5):
37
  for _ in range(retries):
@@ -39,37 +40,58 @@ def insert_texts_with_retry(rag, texts, retries=3, delay=5):
39
  rag.insert(texts)
40
  return
41
  except Exception as e:
42
- print(f"Error occurred during insertion: {e}. Retrying in {delay} seconds...")
 
 
43
  time.sleep(delay)
44
  raise RuntimeError("Failed to insert texts after multiple retries.")
45
 
 
46
  insert_texts_with_retry(rag, texts)
47
 
48
  # Perform different types of queries and handle potential errors
49
  try:
50
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
 
 
 
 
51
  except Exception as e:
52
  print(f"Error performing naive search: {e}")
53
 
54
  try:
55
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
 
 
 
 
56
  except Exception as e:
57
  print(f"Error performing local search: {e}")
58
 
59
  try:
60
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
 
 
 
 
61
  except Exception as e:
62
  print(f"Error performing global search: {e}")
63
 
64
  try:
65
- print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
 
 
 
 
66
  except Exception as e:
67
  print(f"Error performing hybrid search: {e}")
68
 
 
69
  # Function to clear VRAM resources
70
  def clear_vram():
71
  os.system("sudo nvidia-smi --gpu-reset")
72
 
 
73
  # Regularly clear VRAM to prevent overflow
74
  clear_vram_interval = 3600 # Clear once every hour
75
  start_time = time.time()
 
27
  # Read all .txt files from the TEXT_FILES_DIR directory
28
  texts = []
29
  for filename in os.listdir(TEXT_FILES_DIR):
30
+ if filename.endswith(".txt"):
31
  file_path = os.path.join(TEXT_FILES_DIR, filename)
32
+ with open(file_path, "r", encoding="utf-8") as file:
33
  texts.append(file.read())
34
 
35
+
36
  # Batch insert texts into LightRAG with a retry mechanism
37
  def insert_texts_with_retry(rag, texts, retries=3, delay=5):
38
  for _ in range(retries):
 
40
  rag.insert(texts)
41
  return
42
  except Exception as e:
43
+ print(
44
+ f"Error occurred during insertion: {e}. Retrying in {delay} seconds..."
45
+ )
46
  time.sleep(delay)
47
  raise RuntimeError("Failed to insert texts after multiple retries.")
48
 
49
+
50
  insert_texts_with_retry(rag, texts)
51
 
52
  # Perform different types of queries and handle potential errors
53
  try:
54
+ print(
55
+ rag.query(
56
+ "What are the top themes in this story?", param=QueryParam(mode="naive")
57
+ )
58
+ )
59
  except Exception as e:
60
  print(f"Error performing naive search: {e}")
61
 
62
  try:
63
+ print(
64
+ rag.query(
65
+ "What are the top themes in this story?", param=QueryParam(mode="local")
66
+ )
67
+ )
68
  except Exception as e:
69
  print(f"Error performing local search: {e}")
70
 
71
  try:
72
+ print(
73
+ rag.query(
74
+ "What are the top themes in this story?", param=QueryParam(mode="global")
75
+ )
76
+ )
77
  except Exception as e:
78
  print(f"Error performing global search: {e}")
79
 
80
  try:
81
+ print(
82
+ rag.query(
83
+ "What are the top themes in this story?", param=QueryParam(mode="hybrid")
84
+ )
85
+ )
86
  except Exception as e:
87
  print(f"Error performing hybrid search: {e}")
88
 
89
+
90
  # Function to clear VRAM resources
91
  def clear_vram():
92
  os.system("sudo nvidia-smi --gpu-reset")
93
 
94
+
95
  # Regularly clear VRAM to prevent overflow
96
  clear_vram_interval = 3600 # Clear once every hour
97
  start_time = time.time()
lightrag/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
  from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
2
 
3
- __version__ = "0.0.7"
4
  __author__ = "Zirui Guo"
5
  __url__ = "https://github.com/HKUDS/LightRAG"
 
1
  from .lightrag import LightRAG as LightRAG, QueryParam as QueryParam
2
 
3
+ __version__ = "0.0.8"
4
  __author__ = "Zirui Guo"
5
  __url__ = "https://github.com/HKUDS/LightRAG"
lightrag/lightrag.py CHANGED
@@ -109,6 +109,7 @@ class LightRAG:
109
  llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
110
  llm_model_max_token_size: int = 32768
111
  llm_model_max_async: int = 16
 
112
 
113
  # storage
114
  key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
@@ -179,7 +180,11 @@ class LightRAG:
179
  )
180
 
181
  self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
182
- partial(self.llm_model_func, hashing_kv=self.llm_response_cache)
 
 
 
 
183
  )
184
  def _get_storage_class(self) -> Type[BaseGraphStorage]:
185
  return {
@@ -239,7 +244,7 @@ class LightRAG:
239
  logger.info("[Entity Extraction]...")
240
  maybe_new_kg = await extract_entities(
241
  inserting_chunks,
242
- knwoledge_graph_inst=self.chunk_entity_relation_graph,
243
  entity_vdb=self.entities_vdb,
244
  relationships_vdb=self.relationships_vdb,
245
  global_config=asdict(self),
 
109
  llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct" #'meta-llama/Llama-3.2-1B'#'google/gemma-2-2b-it'
110
  llm_model_max_token_size: int = 32768
111
  llm_model_max_async: int = 16
112
+ llm_model_kwargs: dict = field(default_factory=dict)
113
 
114
  # storage
115
  key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
 
180
  )
181
 
182
  self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
183
+ partial(
184
+ self.llm_model_func,
185
+ hashing_kv=self.llm_response_cache,
186
+ **self.llm_model_kwargs,
187
+ )
188
  )
189
  def _get_storage_class(self) -> Type[BaseGraphStorage]:
190
  return {
 
244
  logger.info("[Entity Extraction]...")
245
  maybe_new_kg = await extract_entities(
246
  inserting_chunks,
247
+ knowledge_graph_inst=self.chunk_entity_relation_graph,
248
  entity_vdb=self.entities_vdb,
249
  relationships_vdb=self.relationships_vdb,
250
  global_config=asdict(self),
lightrag/llm.py CHANGED
@@ -7,7 +7,13 @@ import aiohttp
7
  import numpy as np
8
  import ollama
9
 
10
- from openai import AsyncOpenAI, APIConnectionError, RateLimitError, Timeout, AsyncAzureOpenAI
 
 
 
 
 
 
11
 
12
  import base64
13
  import struct
@@ -70,26 +76,31 @@ async def openai_complete_if_cache(
70
  )
71
  return response.choices[0].message.content
72
 
 
73
  @retry(
74
  stop=stop_after_attempt(3),
75
  wait=wait_exponential(multiplier=1, min=4, max=10),
76
  retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
77
  )
78
- async def azure_openai_complete_if_cache(model,
 
79
  prompt,
80
  system_prompt=None,
81
  history_messages=[],
82
  base_url=None,
83
  api_key=None,
84
- **kwargs):
 
85
  if api_key:
86
  os.environ["AZURE_OPENAI_API_KEY"] = api_key
87
  if base_url:
88
  os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
89
 
90
- openai_async_client = AsyncAzureOpenAI(azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
91
- api_key=os.getenv("AZURE_OPENAI_API_KEY"),
92
- api_version=os.getenv("AZURE_OPENAI_API_VERSION"))
 
 
93
 
94
  hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
95
  messages = []
@@ -114,6 +125,7 @@ async def azure_openai_complete_if_cache(model,
114
  )
115
  return response.choices[0].message.content
116
 
 
117
  class BedrockError(Exception):
118
  """Generic error for issues related to Amazon Bedrock"""
119
 
@@ -205,8 +217,12 @@ async def bedrock_complete_if_cache(
205
 
206
  @lru_cache(maxsize=1)
207
  def initialize_hf_model(model_name):
208
- hf_tokenizer = AutoTokenizer.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
209
- hf_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", trust_remote_code=True)
 
 
 
 
210
  if hf_tokenizer.pad_token is None:
211
  hf_tokenizer.pad_token = hf_tokenizer.eos_token
212
 
@@ -266,10 +282,13 @@ async def hf_model_if_cache(
266
  input_ids = hf_tokenizer(
267
  input_prompt, return_tensors="pt", padding=True, truncation=True
268
  ).to("cuda")
 
269
  output = hf_model.generate(
270
- **input_ids, max_new_tokens=200, num_return_sequences=1, early_stopping=True
 
 
 
271
  )
272
- response_text = hf_tokenizer.decode(output[0], skip_special_tokens=True)
273
  if hashing_kv is not None:
274
  await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
275
  return response_text
@@ -280,8 +299,10 @@ async def ollama_model_if_cache(
280
  ) -> str:
281
  kwargs.pop("max_tokens", None)
282
  kwargs.pop("response_format", None)
 
 
283
 
284
- ollama_client = ollama.AsyncClient()
285
  messages = []
286
  if system_prompt:
287
  messages.append({"role": "system", "content": system_prompt})
@@ -305,6 +326,135 @@ async def ollama_model_if_cache(
305
  return result
306
 
307
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
  async def gpt_4o_complete(
309
  prompt, system_prompt=None, history_messages=[], **kwargs
310
  ) -> str:
@@ -328,8 +478,9 @@ async def gpt_4o_mini_complete(
328
  **kwargs,
329
  )
330
 
 
331
  async def azure_openai_complete(
332
- prompt, system_prompt=None, history_messages=[], **kwargs
333
  ) -> str:
334
  return await azure_openai_complete_if_cache(
335
  "conversation-4o-mini",
@@ -339,6 +490,7 @@ async def azure_openai_complete(
339
  **kwargs,
340
  )
341
 
 
342
  async def bedrock_complete(
343
  prompt, system_prompt=None, history_messages=[], **kwargs
344
  ) -> str:
@@ -418,9 +570,11 @@ async def azure_openai_embedding(
418
  if base_url:
419
  os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
420
 
421
- openai_async_client = AsyncAzureOpenAI(azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
422
- api_key=os.getenv("AZURE_OPENAI_API_KEY"),
423
- api_version=os.getenv("AZURE_OPENAI_API_VERSION"))
 
 
424
 
425
  response = await openai_async_client.embeddings.create(
426
  model=model, input=texts, encoding_format="float"
@@ -440,35 +594,28 @@ async def siliconcloud_embedding(
440
  max_token_size: int = 512,
441
  api_key: str = None,
442
  ) -> np.ndarray:
443
- if api_key and not api_key.startswith('Bearer '):
444
- api_key = 'Bearer ' + api_key
445
 
446
- headers = {
447
- "Authorization": api_key,
448
- "Content-Type": "application/json"
449
- }
450
 
451
  truncate_texts = [text[0:max_token_size] for text in texts]
452
 
453
- payload = {
454
- "model": model,
455
- "input": truncate_texts,
456
- "encoding_format": "base64"
457
- }
458
 
459
  base64_strings = []
460
  async with aiohttp.ClientSession() as session:
461
  async with session.post(base_url, headers=headers, json=payload) as response:
462
  content = await response.json()
463
- if 'code' in content:
464
  raise ValueError(content)
465
- base64_strings = [item['embedding'] for item in content['data']]
466
-
467
  embeddings = []
468
  for string in base64_strings:
469
  decode_bytes = base64.b64decode(string)
470
  n = len(decode_bytes) // 4
471
- float_array = struct.unpack('<' + 'f' * n, decode_bytes)
472
  embeddings.append(float_array)
473
  return np.array(embeddings)
474
 
@@ -555,14 +702,16 @@ async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
555
  return embeddings.detach().numpy()
556
 
557
 
558
- async def ollama_embedding(texts: list[str], embed_model) -> np.ndarray:
559
  embed_text = []
 
560
  for text in texts:
561
- data = ollama.embeddings(model=embed_model, prompt=text)
562
  embed_text.append(data["embedding"])
563
 
564
  return embed_text
565
 
 
566
  class Model(BaseModel):
567
  """
568
  This is a Pydantic model class named 'Model' that is used to define a custom language model.
@@ -580,14 +729,20 @@ class Model(BaseModel):
580
  The 'kwargs' dictionary contains the model name and API key to be passed to the function.
581
  """
582
 
583
- gen_func: Callable[[Any], str] = Field(..., description="A function that generates the response from the llm. The response must be a string")
584
- kwargs: Dict[str, Any] = Field(..., description="The arguments to pass to the callable function. Eg. the api key, model name, etc")
 
 
 
 
 
 
585
 
586
  class Config:
587
  arbitrary_types_allowed = True
588
 
589
 
590
- class MultiModel():
591
  """
592
  Distributes the load across multiple language models. Useful for circumventing low rate limits with certain api providers especially if you are on the free tier.
593
  Could also be used for spliting across diffrent models or providers.
@@ -611,26 +766,31 @@ class MultiModel():
611
  )
612
  ```
613
  """
 
614
  def __init__(self, models: List[Model]):
615
  self._models = models
616
  self._current_model = 0
617
-
618
  def _next_model(self):
619
  self._current_model = (self._current_model + 1) % len(self._models)
620
  return self._models[self._current_model]
621
 
622
  async def llm_model_func(
623
- self,
624
- prompt, system_prompt=None, history_messages=[], **kwargs
625
  ) -> str:
626
- kwargs.pop("model", None) # stop from overwriting the custom model name
627
  next_model = self._next_model()
628
- args = dict(prompt=prompt, system_prompt=system_prompt, history_messages=history_messages, **kwargs, **next_model.kwargs)
629
-
630
- return await next_model.gen_func(
631
- **args
 
 
632
  )
633
 
 
 
 
634
  if __name__ == "__main__":
635
  import asyncio
636
 
 
7
  import numpy as np
8
  import ollama
9
 
10
+ from openai import (
11
+ AsyncOpenAI,
12
+ APIConnectionError,
13
+ RateLimitError,
14
+ Timeout,
15
+ AsyncAzureOpenAI,
16
+ )
17
 
18
  import base64
19
  import struct
 
76
  )
77
  return response.choices[0].message.content
78
 
79
+
80
  @retry(
81
  stop=stop_after_attempt(3),
82
  wait=wait_exponential(multiplier=1, min=4, max=10),
83
  retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
84
  )
85
+ async def azure_openai_complete_if_cache(
86
+ model,
87
  prompt,
88
  system_prompt=None,
89
  history_messages=[],
90
  base_url=None,
91
  api_key=None,
92
+ **kwargs,
93
+ ):
94
  if api_key:
95
  os.environ["AZURE_OPENAI_API_KEY"] = api_key
96
  if base_url:
97
  os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
98
 
99
+ openai_async_client = AsyncAzureOpenAI(
100
+ azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
101
+ api_key=os.getenv("AZURE_OPENAI_API_KEY"),
102
+ api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
103
+ )
104
 
105
  hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
106
  messages = []
 
125
  )
126
  return response.choices[0].message.content
127
 
128
+
129
  class BedrockError(Exception):
130
  """Generic error for issues related to Amazon Bedrock"""
131
 
 
217
 
218
  @lru_cache(maxsize=1)
219
  def initialize_hf_model(model_name):
220
+ hf_tokenizer = AutoTokenizer.from_pretrained(
221
+ model_name, device_map="auto", trust_remote_code=True
222
+ )
223
+ hf_model = AutoModelForCausalLM.from_pretrained(
224
+ model_name, device_map="auto", trust_remote_code=True
225
+ )
226
  if hf_tokenizer.pad_token is None:
227
  hf_tokenizer.pad_token = hf_tokenizer.eos_token
228
 
 
282
  input_ids = hf_tokenizer(
283
  input_prompt, return_tensors="pt", padding=True, truncation=True
284
  ).to("cuda")
285
+ inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
286
  output = hf_model.generate(
287
+ **input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True
288
+ )
289
+ response_text = hf_tokenizer.decode(
290
+ output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
291
  )
 
292
  if hashing_kv is not None:
293
  await hashing_kv.upsert({args_hash: {"return": response_text, "model": model}})
294
  return response_text
 
299
  ) -> str:
300
  kwargs.pop("max_tokens", None)
301
  kwargs.pop("response_format", None)
302
+ host = kwargs.pop("host", None)
303
+ timeout = kwargs.pop("timeout", None)
304
 
305
+ ollama_client = ollama.AsyncClient(host=host, timeout=timeout)
306
  messages = []
307
  if system_prompt:
308
  messages.append({"role": "system", "content": system_prompt})
 
326
  return result
327
 
328
 
329
@lru_cache(maxsize=1)
def initialize_lmdeploy_pipeline(
    model,
    tp=1,
    chat_template=None,
    log_level="WARNING",
    model_format="hf",
    quant_policy=0,
):
    """Build an lmdeploy pipeline for ``model`` and memoize it.

    Because of ``lru_cache(maxsize=1)`` only the pipeline for the most recent
    argument combination is kept; calling again with the same arguments
    returns the cached object, calling with different ones builds a new
    pipeline and evicts the previous entry. All arguments must be hashable.

    Args:
        model: Passed to ``lmdeploy.pipeline`` as ``model_path``.
        tp: Forwarded to ``TurbomindEngineConfig(tp=...)``.
        chat_template: When truthy, wrapped in
            ``ChatTemplateConfig(model_name=chat_template)``; otherwise no
            chat template config is passed.
        log_level: Forwarded to ``lmdeploy.pipeline(log_level=...)``.
        model_format: Forwarded to ``TurbomindEngineConfig``.
        quant_policy: Forwarded to ``TurbomindEngineConfig``.

    Returns:
        The object returned by ``lmdeploy.pipeline``.
    """
    # Imported lazily so that lmdeploy stays an optional dependency.
    from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig

    chat_template_config = (
        ChatTemplateConfig(model_name=chat_template) if chat_template else None
    )
    lmdeploy_pipe = pipeline(
        model_path=model,
        backend_config=TurbomindEngineConfig(
            tp=tp, model_format=model_format, quant_policy=quant_policy
        ),
        chat_template_config=chat_template_config,
        # Honor the caller's choice; this used to be pinned to "WARNING",
        # which made the ``log_level`` parameter a no-op.
        log_level=log_level,
    )
    return lmdeploy_pipe
351
+
352
+
353
async def lmdeploy_model_if_cache(
    model,
    prompt,
    system_prompt=None,
    history_messages=[],
    chat_template=None,
    model_format="hf",
    quant_policy=0,
    **kwargs,
) -> str:
    """Generate a chat completion through lmdeploy, consulting a cache first.

    Args:
        model (str): The path to the model.
            It could be one of the following options:
                - i) A local directory path of a turbomind model which is
                    converted by `lmdeploy convert` command or download
                    from ii) and iii).
                - ii) The model_id of a lmdeploy-quantized model hosted
                    inside a model repo on huggingface.co, such as
                    "InternLM/internlm-chat-20b-4bit",
                    "lmdeploy/llama2-chat-70b-4bit", etc.
                - iii) The model_id of a model hosted inside a model repo
                    on huggingface.co, such as "internlm/internlm-chat-7b",
                    "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
                    and so on.
        prompt (str): The user message; appended as the last chat message.
        system_prompt (str | None): Optional system message, placed first.
        history_messages (list): Earlier chat messages, inserted between the
            system message and the user message. The list is only read.
        chat_template (str): needed when model is a pytorch model on
            huggingface.co, such as "internlm-chat-7b",
            "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on,
            and when the model name of local path did not match the original
            model name in HF.
        model_format: Forwarded to ``initialize_lmdeploy_pipeline``.
        quant_policy: Forwarded to ``initialize_lmdeploy_pipeline``.
        **kwargs: The following keys are consumed here; everything else is
            forwarded to ``lmdeploy.GenerationConfig``:
                - response_format: discarded.
                - max_tokens (int): becomes ``max_new_tokens``. Default 512.
                - tp (int): tensor parallel. Default 1.
                - skip_special_tokens (bool): Whether or not to remove
                    special tokens in the decoding. Default True.
                - do_preprocess (bool): whether pre-process the messages.
                    Default True, which means chat_template will be applied.
                - do_sample (bool): Whether or not to use sampling, use
                    greedy decoding otherwise. Default False. Only accepted
                    with lmdeploy >= 0.6.0.
                - hashing_kv: optional cache object; when given, a cached
                    answer for the same (model, messages) is returned
                    without running the model, and new answers are stored.

    Returns:
        The generated text (all streamed pieces concatenated).

    Raises:
        ImportError: lmdeploy could not be imported.
        RuntimeError: ``do_sample`` was passed but lmdeploy is older than
            0.6.0.
    """
    try:
        import lmdeploy
        from lmdeploy import version_info, GenerationConfig
    except Exception as exc:
        # Chain the original failure so the real cause stays visible.
        raise ImportError(
            "Please install lmdeploy before intialize lmdeploy backend."
        ) from exc

    kwargs.pop("response_format", None)
    max_new_tokens = kwargs.pop("max_tokens", 512)
    tp = kwargs.pop("tp", 1)
    skip_special_tokens = kwargs.pop("skip_special_tokens", True)
    do_preprocess = kwargs.pop("do_preprocess", True)
    # None means "caller did not ask"; this lets old lmdeploy versions work
    # as long as the option is not requested explicitly.
    do_sample = kwargs.pop("do_sample", None)
    hashing_kv: BaseKVStorage = kwargs.pop("hashing_kv", None)
    gen_params = kwargs

    if version_info < (0, 6, 0):
        if do_sample is not None:
            raise RuntimeError(
                "`do_sample` parameter is not supported by lmdeploy until "
                f"v0.6.0, but currently using lmdeloy {lmdeploy.__version__}"
            )
    else:
        # Pass the caller's choice through (greedy when unspecified) instead
        # of overwriting it with True.
        gen_params.update(do_sample=bool(do_sample))

    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    # Look in the cache before building the pipeline so that a cache hit
    # never triggers a model load.
    if hashing_kv is not None:
        args_hash = compute_args_hash(model, messages)
        if_cache_return = await hashing_kv.get_by_id(args_hash)
        if if_cache_return is not None:
            return if_cache_return["return"]

    lmdeploy_pipe = initialize_lmdeploy_pipeline(
        model=model,
        tp=tp,
        chat_template=chat_template,
        model_format=model_format,
        quant_policy=quant_policy,
        log_level="WARNING",
    )

    gen_config = GenerationConfig(
        skip_special_tokens=skip_special_tokens,
        max_new_tokens=max_new_tokens,
        **gen_params,
    )

    # NOTE(review): session_id is fixed at 1; confirm this is safe when
    # several requests run concurrently against the same pipeline.
    response = ""
    async for res in lmdeploy_pipe.generate(
        messages,
        gen_config=gen_config,
        do_preprocess=do_preprocess,
        stream_response=False,
        session_id=1,
    ):
        response += res.response

    if hashing_kv is not None:
        await hashing_kv.upsert({args_hash: {"return": response, "model": model}})
    return response
456
+
457
+
458
  async def gpt_4o_complete(
459
  prompt, system_prompt=None, history_messages=[], **kwargs
460
  ) -> str:
 
478
  **kwargs,
479
  )
480
 
481
+
482
  async def azure_openai_complete(
483
+ prompt, system_prompt=None, history_messages=[], **kwargs
484
  ) -> str:
485
  return await azure_openai_complete_if_cache(
486
  "conversation-4o-mini",
 
490
  **kwargs,
491
  )
492
 
493
+
494
  async def bedrock_complete(
495
  prompt, system_prompt=None, history_messages=[], **kwargs
496
  ) -> str:
 
570
  if base_url:
571
  os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
572
 
573
+ openai_async_client = AsyncAzureOpenAI(
574
+ azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
575
+ api_key=os.getenv("AZURE_OPENAI_API_KEY"),
576
+ api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
577
+ )
578
 
579
  response = await openai_async_client.embeddings.create(
580
  model=model, input=texts, encoding_format="float"
 
594
  max_token_size: int = 512,
595
  api_key: str = None,
596
  ) -> np.ndarray:
597
+ if api_key and not api_key.startswith("Bearer "):
598
+ api_key = "Bearer " + api_key
599
 
600
+ headers = {"Authorization": api_key, "Content-Type": "application/json"}
 
 
 
601
 
602
  truncate_texts = [text[0:max_token_size] for text in texts]
603
 
604
+ payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"}
 
 
 
 
605
 
606
  base64_strings = []
607
  async with aiohttp.ClientSession() as session:
608
  async with session.post(base_url, headers=headers, json=payload) as response:
609
  content = await response.json()
610
+ if "code" in content:
611
  raise ValueError(content)
612
+ base64_strings = [item["embedding"] for item in content["data"]]
613
+
614
  embeddings = []
615
  for string in base64_strings:
616
  decode_bytes = base64.b64decode(string)
617
  n = len(decode_bytes) // 4
618
+ float_array = struct.unpack("<" + "f" * n, decode_bytes)
619
  embeddings.append(float_array)
620
  return np.array(embeddings)
621
 
 
702
  return embeddings.detach().numpy()
703
 
704
 
705
async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
    """Embed each text with an Ollama embedding model.

    Args:
        texts: Strings to embed; one request is issued per string, in order.
        embed_model: Model name passed to the Ollama ``embeddings`` call.
        **kwargs: Forwarded to the Ollama client constructor (for example
            ``host``).

    Returns:
        An array built from the per-text ``"embedding"`` vectors, one row per
        input text.
    """
    # Use the async client so awaiting a request does not block the event
    # loop the way the synchronous client did inside this coroutine.
    ollama_client = ollama.AsyncClient(**kwargs)
    embed_text = []
    for text in texts:
        data = await ollama_client.embeddings(model=embed_model, prompt=text)
        embed_text.append(data["embedding"])

    # Return an ndarray, as annotated and as the other embedding helpers do.
    return np.array(embed_text)
713
 
714
+
715
  class Model(BaseModel):
716
  """
717
  This is a Pydantic model class named 'Model' that is used to define a custom language model.
 
729
  The 'kwargs' dictionary contains the model name and API key to be passed to the function.
730
  """
731
 
732
+ gen_func: Callable[[Any], str] = Field(
733
+ ...,
734
+ description="A function that generates the response from the llm. The response must be a string",
735
+ )
736
+ kwargs: Dict[str, Any] = Field(
737
+ ...,
738
+ description="The arguments to pass to the callable function. Eg. the api key, model name, etc",
739
+ )
740
 
741
  class Config:
742
  arbitrary_types_allowed = True
743
 
744
 
745
+ class MultiModel:
746
  """
747
  Distributes the load across multiple language models. Useful for circumventing low rate limits with certain api providers especially if you are on the free tier.
748
  Could also be used for spliting across diffrent models or providers.
 
766
  )
767
  ```
768
  """
769
+
770
  def __init__(self, models: List[Model]):
771
  self._models = models
772
  self._current_model = 0
773
+
774
  def _next_model(self):
775
  self._current_model = (self._current_model + 1) % len(self._models)
776
  return self._models[self._current_model]
777
 
778
  async def llm_model_func(
779
+ self, prompt, system_prompt=None, history_messages=[], **kwargs
 
780
  ) -> str:
781
+ kwargs.pop("model", None) # stop from overwriting the custom model name
782
  next_model = self._next_model()
783
+ args = dict(
784
+ prompt=prompt,
785
+ system_prompt=system_prompt,
786
+ history_messages=history_messages,
787
+ **kwargs,
788
+ **next_model.kwargs,
789
  )
790
 
791
+ return await next_model.gen_func(**args)
792
+
793
+
794
  if __name__ == "__main__":
795
  import asyncio
796
 
lightrag/operate.py CHANGED
@@ -124,14 +124,14 @@ async def _handle_single_relationship_extraction(
124
  async def _merge_nodes_then_upsert(
125
  entity_name: str,
126
  nodes_data: list[dict],
127
- knwoledge_graph_inst: BaseGraphStorage,
128
  global_config: dict,
129
  ):
130
  already_entitiy_types = []
131
  already_source_ids = []
132
  already_description = []
133
 
134
- already_node = await knwoledge_graph_inst.get_node(entity_name)
135
  if already_node is not None:
136
  already_entitiy_types.append(already_node["entity_type"])
137
  already_source_ids.extend(
@@ -160,7 +160,7 @@ async def _merge_nodes_then_upsert(
160
  description=description,
161
  source_id=source_id,
162
  )
163
- await knwoledge_graph_inst.upsert_node(
164
  entity_name,
165
  node_data=node_data,
166
  )
@@ -172,7 +172,7 @@ async def _merge_edges_then_upsert(
172
  src_id: str,
173
  tgt_id: str,
174
  edges_data: list[dict],
175
- knwoledge_graph_inst: BaseGraphStorage,
176
  global_config: dict,
177
  ):
178
  already_weights = []
@@ -180,8 +180,8 @@ async def _merge_edges_then_upsert(
180
  already_description = []
181
  already_keywords = []
182
 
183
- if await knwoledge_graph_inst.has_edge(src_id, tgt_id):
184
- already_edge = await knwoledge_graph_inst.get_edge(src_id, tgt_id)
185
  already_weights.append(already_edge["weight"])
186
  already_source_ids.extend(
187
  split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
@@ -202,8 +202,8 @@ async def _merge_edges_then_upsert(
202
  set([dp["source_id"] for dp in edges_data] + already_source_ids)
203
  )
204
  for need_insert_id in [src_id, tgt_id]:
205
- if not (await knwoledge_graph_inst.has_node(need_insert_id)):
206
- await knwoledge_graph_inst.upsert_node(
207
  need_insert_id,
208
  node_data={
209
  "source_id": source_id,
@@ -214,7 +214,7 @@ async def _merge_edges_then_upsert(
214
  description = await _handle_entity_relation_summary(
215
  (src_id, tgt_id), description, global_config
216
  )
217
- await knwoledge_graph_inst.upsert_edge(
218
  src_id,
219
  tgt_id,
220
  edge_data=dict(
@@ -237,7 +237,7 @@ async def _merge_edges_then_upsert(
237
 
238
  async def extract_entities(
239
  chunks: dict[str, TextChunkSchema],
240
- knwoledge_graph_inst: BaseGraphStorage,
241
  entity_vdb: BaseVectorStorage,
242
  relationships_vdb: BaseVectorStorage,
243
  global_config: dict,
@@ -341,13 +341,13 @@ async def extract_entities(
341
  maybe_edges[tuple(sorted(k))].extend(v)
342
  all_entities_data = await asyncio.gather(
343
  *[
344
- _merge_nodes_then_upsert(k, v, knwoledge_graph_inst, global_config)
345
  for k, v in maybe_nodes.items()
346
  ]
347
  )
348
  all_relationships_data = await asyncio.gather(
349
  *[
350
- _merge_edges_then_upsert(k[0], k[1], v, knwoledge_graph_inst, global_config)
351
  for k, v in maybe_edges.items()
352
  ]
353
  )
@@ -384,7 +384,7 @@ async def extract_entities(
384
  }
385
  await relationships_vdb.upsert(data_for_vdb)
386
 
387
- return knwoledge_graph_inst
388
 
389
 
390
  async def local_query(
 
124
  async def _merge_nodes_then_upsert(
125
  entity_name: str,
126
  nodes_data: list[dict],
127
+ knowledge_graph_inst: BaseGraphStorage,
128
  global_config: dict,
129
  ):
130
  already_entitiy_types = []
131
  already_source_ids = []
132
  already_description = []
133
 
134
+ already_node = await knowledge_graph_inst.get_node(entity_name)
135
  if already_node is not None:
136
  already_entitiy_types.append(already_node["entity_type"])
137
  already_source_ids.extend(
 
160
  description=description,
161
  source_id=source_id,
162
  )
163
+ await knowledge_graph_inst.upsert_node(
164
  entity_name,
165
  node_data=node_data,
166
  )
 
172
  src_id: str,
173
  tgt_id: str,
174
  edges_data: list[dict],
175
+ knowledge_graph_inst: BaseGraphStorage,
176
  global_config: dict,
177
  ):
178
  already_weights = []
 
180
  already_description = []
181
  already_keywords = []
182
 
183
+ if await knowledge_graph_inst.has_edge(src_id, tgt_id):
184
+ already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
185
  already_weights.append(already_edge["weight"])
186
  already_source_ids.extend(
187
  split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
 
202
  set([dp["source_id"] for dp in edges_data] + already_source_ids)
203
  )
204
  for need_insert_id in [src_id, tgt_id]:
205
+ if not (await knowledge_graph_inst.has_node(need_insert_id)):
206
+ await knowledge_graph_inst.upsert_node(
207
  need_insert_id,
208
  node_data={
209
  "source_id": source_id,
 
214
  description = await _handle_entity_relation_summary(
215
  (src_id, tgt_id), description, global_config
216
  )
217
+ await knowledge_graph_inst.upsert_edge(
218
  src_id,
219
  tgt_id,
220
  edge_data=dict(
 
237
 
238
  async def extract_entities(
239
  chunks: dict[str, TextChunkSchema],
240
+ knowledge_graph_inst: BaseGraphStorage,
241
  entity_vdb: BaseVectorStorage,
242
  relationships_vdb: BaseVectorStorage,
243
  global_config: dict,
 
341
  maybe_edges[tuple(sorted(k))].extend(v)
342
  all_entities_data = await asyncio.gather(
343
  *[
344
+ _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
345
  for k, v in maybe_nodes.items()
346
  ]
347
  )
348
  all_relationships_data = await asyncio.gather(
349
  *[
350
+ _merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config)
351
  for k, v in maybe_edges.items()
352
  ]
353
  )
 
384
  }
385
  await relationships_vdb.upsert(data_for_vdb)
386
 
387
+ return knowledge_graph_inst
388
 
389
 
390
  async def local_query(
lightrag/utils.py CHANGED
@@ -185,6 +185,7 @@ def save_data_to_file(data, file_name):
185
  with open(file_name, "w", encoding="utf-8") as f:
186
  json.dump(data, f, ensure_ascii=False, indent=4)
187
 
 
188
  def xml_to_json(xml_file):
189
  try:
190
  tree = ET.parse(xml_file)
@@ -194,31 +195,42 @@ def xml_to_json(xml_file):
194
  print(f"Root element: {root.tag}")
195
  print(f"Root attributes: {root.attrib}")
196
 
197
- data = {
198
- "nodes": [],
199
- "edges": []
200
- }
201
 
202
  # Use namespace
203
- namespace = {'': 'http://graphml.graphdrawing.org/xmlns'}
204
 
205
- for node in root.findall('.//node', namespace):
206
  node_data = {
207
- "id": node.get('id').strip('"'),
208
- "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"') if node.find("./data[@key='d0']", namespace) is not None else "",
209
- "description": node.find("./data[@key='d1']", namespace).text if node.find("./data[@key='d1']", namespace) is not None else "",
210
- "source_id": node.find("./data[@key='d2']", namespace).text if node.find("./data[@key='d2']", namespace) is not None else ""
 
 
 
 
 
 
211
  }
212
  data["nodes"].append(node_data)
213
 
214
- for edge in root.findall('.//edge', namespace):
215
  edge_data = {
216
- "source": edge.get('source').strip('"'),
217
- "target": edge.get('target').strip('"'),
218
- "weight": float(edge.find("./data[@key='d3']", namespace).text) if edge.find("./data[@key='d3']", namespace) is not None else 0.0,
219
- "description": edge.find("./data[@key='d4']", namespace).text if edge.find("./data[@key='d4']", namespace) is not None else "",
220
- "keywords": edge.find("./data[@key='d5']", namespace).text if edge.find("./data[@key='d5']", namespace) is not None else "",
221
- "source_id": edge.find("./data[@key='d6']", namespace).text if edge.find("./data[@key='d6']", namespace) is not None else ""
 
 
 
 
 
 
 
 
222
  }
223
  data["edges"].append(edge_data)
224
 
 
185
  with open(file_name, "w", encoding="utf-8") as f:
186
  json.dump(data, f, ensure_ascii=False, indent=4)
187
 
188
+
189
  def xml_to_json(xml_file):
190
  try:
191
  tree = ET.parse(xml_file)
 
195
  print(f"Root element: {root.tag}")
196
  print(f"Root attributes: {root.attrib}")
197
 
198
+ data = {"nodes": [], "edges": []}
 
 
 
199
 
200
  # Use namespace
201
+ namespace = {"": "http://graphml.graphdrawing.org/xmlns"}
202
 
203
+ for node in root.findall(".//node", namespace):
204
  node_data = {
205
+ "id": node.get("id").strip('"'),
206
+ "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"')
207
+ if node.find("./data[@key='d0']", namespace) is not None
208
+ else "",
209
+ "description": node.find("./data[@key='d1']", namespace).text
210
+ if node.find("./data[@key='d1']", namespace) is not None
211
+ else "",
212
+ "source_id": node.find("./data[@key='d2']", namespace).text
213
+ if node.find("./data[@key='d2']", namespace) is not None
214
+ else "",
215
  }
216
  data["nodes"].append(node_data)
217
 
218
+ for edge in root.findall(".//edge", namespace):
219
  edge_data = {
220
+ "source": edge.get("source").strip('"'),
221
+ "target": edge.get("target").strip('"'),
222
+ "weight": float(edge.find("./data[@key='d3']", namespace).text)
223
+ if edge.find("./data[@key='d3']", namespace) is not None
224
+ else 0.0,
225
+ "description": edge.find("./data[@key='d4']", namespace).text
226
+ if edge.find("./data[@key='d4']", namespace) is not None
227
+ else "",
228
+ "keywords": edge.find("./data[@key='d5']", namespace).text
229
+ if edge.find("./data[@key='d5']", namespace) is not None
230
+ else "",
231
+ "source_id": edge.find("./data[@key='d6']", namespace).text
232
+ if edge.find("./data[@key='d6']", namespace) is not None
233
+ else "",
234
  }
235
  data["edges"].append(edge_data)
236
 
reproduce/Step_3.py CHANGED
@@ -18,8 +18,8 @@ def extract_queries(file_path):
18
 
19
  async def process_query(query_text, rag_instance, query_param):
20
  try:
21
- result, context = await rag_instance.aquery(query_text, param=query_param)
22
- return {"query": query_text, "result": result, "context": context}, None
23
  except Exception as e:
24
  return None, {"query": query_text, "error": str(e)}
25
 
 
18
 
19
  async def process_query(query_text, rag_instance, query_param):
20
  try:
21
+ result = await rag_instance.aquery(query_text, param=query_param)
22
+ return {"query": query_text, "result": result}, None
23
  except Exception as e:
24
  return None, {"query": query_text, "error": str(e)}
25
 
reproduce/Step_3_openai_compatible.py CHANGED
@@ -50,8 +50,8 @@ def extract_queries(file_path):
50
 
51
  async def process_query(query_text, rag_instance, query_param):
52
  try:
53
- result, context = await rag_instance.aquery(query_text, param=query_param)
54
- return {"query": query_text, "result": result, "context": context}, None
55
  except Exception as e:
56
  return None, {"query": query_text, "error": str(e)}
57
 
 
50
 
51
  async def process_query(query_text, rag_instance, query_param):
52
  try:
53
+ result = await rag_instance.aquery(query_text, param=query_param)
54
+ return {"query": query_text, "result": result}, None
55
  except Exception as e:
56
  return None, {"query": query_text, "error": str(e)}
57
 
requirements.txt CHANGED
@@ -1,16 +1,17 @@
1
  accelerate
2
  aioboto3
 
3
  graspologic
4
  hnswlib
5
  nano-vectordb
 
6
  networkx
7
  ollama
8
  openai
 
9
  tenacity
10
  tiktoken
11
  torch
12
  transformers
13
  xxhash
14
- pyvis
15
- aiohttp
16
- neo4j
 
1
  accelerate
2
  aioboto3
3
+ aiohttp
4
  graspologic
5
  hnswlib
6
  nano-vectordb
7
+ neo4j
8
  networkx
9
  ollama
10
  openai
11
+ pyvis
12
  tenacity
13
  tiktoken
14
  torch
15
  transformers
16
  xxhash
17
+ # lmdeploy[all]
 
 
setup.py CHANGED
@@ -1,39 +1,88 @@
1
  import setuptools
 
2
 
3
- with open("README.md", "r", encoding="utf-8") as fh:
4
- long_description = fh.read()
5
 
 
 
 
 
 
 
6
 
7
- vars2find = ["__author__", "__version__", "__url__"]
8
- vars2readme = {}
9
- with open("./lightrag/__init__.py") as f:
10
- for line in f.readlines():
11
- for v in vars2find:
12
- if line.startswith(v):
13
- line = line.replace(" ", "").replace('"', "").replace("'", "").strip()
14
- vars2readme[v] = line.split("=")[1]
15
 
16
- deps = []
17
- with open("./requirements.txt") as f:
18
- for line in f.readlines():
19
- if not line.strip():
20
- continue
21
- deps.append(line.strip())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  setuptools.setup(
24
  name="lightrag-hku",
25
- url=vars2readme["__url__"],
26
- version=vars2readme["__version__"],
27
- author=vars2readme["__author__"],
28
  description="LightRAG: Simple and Fast Retrieval-Augmented Generation",
29
  long_description=long_description,
30
  long_description_content_type="text/markdown",
31
- packages=["lightrag"],
 
 
32
  classifiers=[
 
33
  "Programming Language :: Python :: 3",
34
  "License :: OSI Approved :: MIT License",
35
  "Operating System :: OS Independent",
 
 
36
  ],
37
  python_requires=">=3.9",
38
- install_requires=deps,
 
 
 
 
 
 
 
 
39
  )
 
1
  import setuptools
2
+ from pathlib import Path
3
 
 
 
4
 
5
+ # Reading the long description from README.md
6
def read_long_description():
    """Return the text of ``README.md`` for use as the package long description.

    The file is looked up relative to the current working directory and
    decoded as UTF-8. When it does not exist, a fixed placeholder sentence
    is returned instead of raising.
    """
    readme_path = Path("README.md")
    try:
        description = readme_path.read_text(encoding="utf-8")
    except FileNotFoundError:
        # A missing README should not abort the build; fall back to a stub.
        description = "A description of LightRAG is currently unavailable."
    return description
11
 
 
 
 
 
 
 
 
 
12
 
13
+ # Retrieving metadata from __init__.py
14
def retrieve_metadata():
    """Collect package metadata from ``./lightrag/__init__.py``.

    The file is scanned line by line for simple assignments to
    ``__author__``, ``__version__`` and ``__url__``. For each one, the text
    to the right of the first ``=`` is taken, surrounding whitespace and
    quote characters are removed, and the result is stored under the
    variable name. If a name is assigned more than once, the last
    assignment wins.

    Returns:
        dict mapping each of the three variable names to its string value.

    Raises:
        FileNotFoundError: if ``./lightrag/__init__.py`` does not exist.
        ValueError: if any of the three variables is not assigned in the file.
    """
    vars2find = ["__author__", "__version__", "__url__"]
    vars2readme = {}
    try:
        with open("./lightrag/__init__.py", encoding="utf-8") as f:
            for line in f:
                # Split on the FIRST "=" only, so values that themselves
                # contain "=" (e.g. a URL with a query string) stay intact.
                target, sep, value = line.partition("=")
                if not sep:
                    continue
                # Tolerate an annotated assignment such as `__version__: str = ...`.
                name = target.split(":", 1)[0].strip()
                if name not in vars2find:
                    continue
                # Trim only the surrounding whitespace/quotes; spaces inside
                # the value (e.g. an author's full name) must be preserved.
                vars2readme[name] = value.strip().strip("\"'")
    except FileNotFoundError as err:
        raise FileNotFoundError(
            "Metadata file './lightrag/__init__.py' not found."
        ) from err

    # Checking if all required variables are found
    missing_vars = [v for v in vars2find if v not in vars2readme]
    if missing_vars:
        raise ValueError(
            f"Missing required metadata variables in __init__.py: {missing_vars}"
        )

    return vars2readme
40
+
41
+
42
+ # Reading dependencies from requirements.txt
43
def read_requirements():
    """Return the dependency specifiers listed in ``./requirements.txt``.

    Blank lines and comments are skipped, so a commented-out entry is never
    handed to ``install_requires`` as if it were a requirement. A comment is
    a ``#`` at the start of a line or preceded by whitespace, running to the
    end of the line.

    Returns:
        list of requirement strings in file order. The list is empty when
        the file does not exist; in that case a warning is printed rather
        than raising.
    """
    import re  # local import: only this helper needs it

    deps = []
    try:
        with open("./requirements.txt", encoding="utf-8") as f:
            for line in f:
                # Drop full-line and trailing comments before trimming; a "#"
                # that is not preceded by whitespace (e.g. a URL fragment)
                # is left alone.
                requirement = re.sub(r"(^|\s+)#.*$", "", line).strip()
                if requirement:
                    deps.append(requirement)
    except FileNotFoundError:
        print(
            "Warning: 'requirements.txt' not found. No dependencies will be installed."
        )
    return deps
53
+
54
+
55
+ metadata = retrieve_metadata()
56
+ long_description = read_long_description()
57
+ requirements = read_requirements()
58
 
59
  setuptools.setup(
60
  name="lightrag-hku",
61
+ url=metadata["__url__"],
62
+ version=metadata["__version__"],
63
+ author=metadata["__author__"],
64
  description="LightRAG: Simple and Fast Retrieval-Augmented Generation",
65
  long_description=long_description,
66
  long_description_content_type="text/markdown",
67
+ packages=setuptools.find_packages(
68
+ exclude=("tests*", "docs*")
69
+ ), # Automatically find packages
70
  classifiers=[
71
+ "Development Status :: 4 - Beta",
72
  "Programming Language :: Python :: 3",
73
  "License :: OSI Approved :: MIT License",
74
  "Operating System :: OS Independent",
75
+ "Intended Audience :: Developers",
76
+ "Topic :: Software Development :: Libraries :: Python Modules",
77
  ],
78
  python_requires=">=3.9",
79
+ install_requires=requirements,
80
+ include_package_data=True, # Includes non-code files from MANIFEST.in
81
+ project_urls={ # Additional project metadata
82
+ "Documentation": metadata.get("__url__", ""),
83
+ "Source": metadata.get("__url__", ""),
84
+ "Tracker": f"{metadata.get('__url__', '')}/issues"
85
+ if metadata.get("__url__")
86
+ else "",
87
+ },
88
  )