zrguo commited on
Commit
ba4763a
·
2 Parent(s): ae5772f aa1c267

Merge pull request #197 from wiltshirek/main

Browse files
.DS_Store ADDED
Binary file (8.2 kB). View file
 
.gitignore CHANGED
@@ -5,4 +5,8 @@ book.txt
5
  lightrag-dev/
6
  .idea/
7
  dist/
 
 
 
 
8
  .venv/
 
5
  lightrag-dev/
6
  .idea/
7
  dist/
8
+ env/
9
+ local_neo4jWorkDir/
10
+ neo4jWorkDir/
11
+ ignore_this.txt
12
  .venv/
Dockerfile ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Neo4j 5.24.2 (enterprise) on Debian bullseye-slim with a Temurin 17 runtime
# copied in from the eclipse-temurin image.
FROM debian:bullseye-slim
ENV JAVA_HOME=/opt/java/openjdk
COPY --from=eclipse-temurin:17 $JAVA_HOME $JAVA_HOME
ENV PATH="${JAVA_HOME}/bin:${PATH}" \
    NEO4J_SHA256=7ce97bd9a4348af14df442f00b3dc5085b5983d6f03da643744838c7a1bc8ba7 \
    NEO4J_TARBALL=neo4j-enterprise-5.24.2-unix.tar.gz \
    NEO4J_EDITION=enterprise \
    NEO4J_HOME="/var/lib/neo4j" \
    LANG=C.UTF-8
ARG NEO4J_URI=https://dist.neo4j.org/neo4j-enterprise-5.24.2-unix.tar.gz

RUN addgroup --gid 7474 --system neo4j && adduser --uid 7474 --system --no-create-home --home "${NEO4J_HOME}" --ingroup neo4j neo4j

COPY ./local-package/* /startup/

# Use `apt-get update` (not `apt update`): apt's porcelain CLI warns that it
# has no stable interface for scripts.  Tarball and su-exec sources are both
# pinned by sha256 before use.
RUN apt-get update \
    && apt-get install -y curl gcc git jq make procps tini wget \
    && curl --fail --silent --show-error --location --remote-name ${NEO4J_URI} \
    && echo "${NEO4J_SHA256} ${NEO4J_TARBALL}" | sha256sum -c --strict --quiet \
    && tar --extract --file ${NEO4J_TARBALL} --directory /var/lib \
    && mv /var/lib/neo4j-* "${NEO4J_HOME}" \
    && rm ${NEO4J_TARBALL} \
    && sed -i 's/Package Type:.*/Package Type: docker bullseye/' $NEO4J_HOME/packaging_info \
    && mv /startup/neo4j-admin-report.sh "${NEO4J_HOME}"/bin/neo4j-admin-report \
    && mv "${NEO4J_HOME}"/data /data \
    && mv "${NEO4J_HOME}"/logs /logs \
    && chown -R neo4j:neo4j /data \
    && chmod -R 777 /data \
    && chown -R neo4j:neo4j /logs \
    && chmod -R 777 /logs \
    && chown -R neo4j:neo4j "${NEO4J_HOME}" \
    && chmod -R 777 "${NEO4J_HOME}" \
    && chmod -R 755 "${NEO4J_HOME}/bin" \
    && ln -s /data "${NEO4J_HOME}"/data \
    && ln -s /logs "${NEO4J_HOME}"/logs \
    && git clone https://github.com/ncopa/su-exec.git \
    && cd su-exec \
    && git checkout 4c3bb42b093f14da70d8ab924b487ccfbb1397af \
    && echo d6c40440609a23483f12eb6295b5191e94baf08298a856bab6e15b10c3b82891 su-exec.c | sha256sum -c \
    && echo 2a87af245eb125aca9305a0b1025525ac80825590800f047419dc57bba36b334 Makefile | sha256sum -c \
    && make \
    && mv /su-exec/su-exec /usr/bin/su-exec \
    && apt-get -y purge --auto-remove curl gcc git make \
    && rm -rf /var/lib/apt/lists/* /su-exec

# key=value form; the legacy whitespace `ENV PATH value` syntax is deprecated.
ENV PATH="${NEO4J_HOME}/bin:${PATH}"

WORKDIR "${NEO4J_HOME}"

VOLUME /data /logs

EXPOSE 7474 7473 7687

ENTRYPOINT ["tini", "-g", "--", "/startup/docker-entrypoint.sh"]
CMD ["neo4j"]
README.md CHANGED
@@ -161,6 +161,39 @@ rag = LightRAG(
161
  ```
162
  </details>
163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  <details>
165
  <summary> Using Ollama Models </summary>
166
 
 
161
  ```
162
  </details>
163
 
164
+
165
+ <details>
166
+ <summary> Using Neo4J for Storage </summary>
167
+
168
+ * For production level scenarios you will most likely want to leverage an enterprise solution
169
+   for KG storage. Running Neo4J in Docker is recommended for seamless local testing.
170
+ * See: https://hub.docker.com/_/neo4j
171
+
172
+
173
+ ```python
174
+ export NEO4J_URI="neo4j://localhost:7687"
175
+ export NEO4J_USERNAME="neo4j"
176
+ export NEO4J_PASSWORD="password"
177
+
178
+ When you launch the project be sure to override the default KG: NetworkX
179
+ by specifying kg="Neo4JStorage".
180
+
181
+ # Note: Default settings use NetworkX
182
+ #Initialize LightRAG with Neo4J implementation.
183
+ WORKING_DIR = "./local_neo4jWorkDir"
184
+
185
+ rag = LightRAG(
186
+ working_dir=WORKING_DIR,
187
+ llm_model_func=gpt_4o_mini_complete, # Use gpt_4o_mini_complete LLM model
188
+ kg="Neo4JStorage", #<-----------override KG default
189
+ log_level="DEBUG" #<-----------override log_level default
190
+ )
191
+ ```
192
+ See test_neo4j.py for a working example.
193
+ </details>
194
+
195
+
196
+
197
  <details>
198
  <summary> Using Ollama Models </summary>
199
 
get_all_edges_nx.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Dump every edge (with endpoint properties) from a GraphML knowledge graph.

``get_all_edges_and_nodes`` is import-safe and pure; the GraphML file is only
read when the module is executed as a script (the original read it at import
time, which crashed any importer when the hard-coded file was absent).
"""

# Default location written by the LightRAG indexing run.
DEFAULT_GRAPHML = "./dickensTestEmbedcall/graph_chunk_entity_relation.graphml"


def get_all_edges_and_nodes(G):
    """Return a list of dicts describing every edge of *G*.

    Each dict carries: start/end node ids, the edge 'label' (assuming the
    'label' attribute is used for the edge type), the full edge property
    dict, and the property dicts of both endpoint nodes.

    Works with any object exposing ``edges(data=True)`` and a subscriptable
    ``nodes`` mapping (e.g. a ``networkx.Graph``).
    """
    edges_with_properties = []
    for u, v, data in G.edges(data=True):
        edges_with_properties.append(
            {
                "start": u,
                "end": v,
                "label": data.get("label", ""),  # '' when edge type missing
                "properties": data,
                "start_node_properties": G.nodes[u],
                "end_node_properties": G.nodes[v],
            }
        )
    return edges_with_properties


# Example usage
if __name__ == "__main__":
    # Import and load lazily so importing this module has no side effects.
    import networkx as nx

    G = nx.read_graphml(DEFAULT_GRAPHML)
    all_edges = get_all_edges_and_nodes(G)

    # Print all edges and node properties
    for edge in all_edges:
        print(f"Edge Label: {edge['label']}")
        print(f"Edge Properties: {edge['properties']}")
        print(f"Start Node: {edge['start']}")
        print(f"Start Node Properties: {edge['start_node_properties']}")
        print(f"End Node: {edge['end']}")
        print(f"End Node Properties: {edge['end_node_properties']}")
        print("---")
graph_chunk_entity_relation.gefx ADDED
The diff for this file is too large to render. See raw diff
 
lightrag/kg/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # print ("init package vars here. ......")
2
+
3
+
lightrag/kg/neo4j_impl.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import asyncio
import html
import os
from dataclasses import dataclass
from typing import Any, Union, cast, Tuple, List, Dict
import numpy as np
import inspect
from lightrag.utils import load_json, logger, write_json
from ..base import (
    BaseGraphStorage
)
from neo4j import AsyncGraphDatabase, exceptions as neo4jExceptions, AsyncDriver, AsyncSession, AsyncManagedTransaction

from contextlib import asynccontextmanager


from tenacity import (
    retry,
    stop_after_attempt,
    wait_exponential,
    retry_if_exception_type,
)


@dataclass
class Neo4JStorage(BaseGraphStorage):
    """BaseGraphStorage implementation backed by a Neo4j server.

    Entity names are stored as node *labels* (always back-quoted in Cypher to
    tolerate arbitrary characters) and edges are MERGEd as :DIRECTED
    relationships.  Connection settings are read from the NEO4J_URI /
    NEO4J_USERNAME / NEO4J_PASSWORD environment variables.
    """

    @staticmethod
    def load_nx_graph(file_name):
        # Nothing to preload: with Neo4j the graph lives in the database.
        print("no preloading of graph with neo4j in production")

    def __init__(self, namespace, global_config):
        super().__init__(namespace=namespace, global_config=global_config)
        self._driver_lock = asyncio.Lock()
        # Fail fast with KeyError if any connection setting is missing.
        URI = os.environ["NEO4J_URI"]
        USERNAME = os.environ["NEO4J_USERNAME"]
        PASSWORD = os.environ["NEO4J_PASSWORD"]
        self._driver: AsyncDriver = AsyncGraphDatabase.driver(
            URI, auth=(USERNAME, PASSWORD)
        )
        # A user-defined __init__ suppresses the dataclass-generated one, so
        # __post_init__ would never run -- register the embedding algorithms
        # here instead (the original set them in a dead __post_init__).
        self._node_embed_algorithms = {
            "node2vec": self._node2vec_embed,
        }

    async def close(self):
        """Close the async driver and drop the reference (idempotent).

        The original file also defined a *synchronous* ``close`` further down
        the class, which silently shadowed this coroutine and broke
        ``await storage.close()``; that duplicate has been removed.
        """
        if self._driver:
            await self._driver.close()
            self._driver = None

    async def __aexit__(self, exc_type, exc, tb):
        # Async context-manager exit: release the driver's connection pool.
        if self._driver:
            await self._driver.close()

    async def index_done_callback(self):
        print("KG successfully indexed.")

    async def has_node(self, node_id: str) -> bool:
        """Return True if at least one node carries the given label."""
        entity_name_label = node_id.strip('"')

        async with self._driver.session() as session:
            query = f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
            result = await session.run(query)
            single_result = await result.single()
            logger.debug(
                f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
            )
            return single_result["node_exists"]

    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
        """Return True if any relationship (either direction) links the labels."""
        entity_name_label_source = source_node_id.strip('"')
        entity_name_label_target = target_node_id.strip('"')

        async with self._driver.session() as session:
            query = (
                f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
                "RETURN COUNT(r) > 0 AS edgeExists"
            )
            result = await session.run(query)
            single_result = await result.single()
            logger.debug(
                f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
            )
            return single_result["edgeExists"]

    async def get_node(self, node_id: str) -> Union[dict, None]:
        """Return the first node with the given label as a plain dict, or None."""
        async with self._driver.session() as session:
            entity_name_label = node_id.strip('"')
            query = f"MATCH (n:`{entity_name_label}`) RETURN n"
            result = await session.run(query)
            record = await result.single()
            if record:
                node = record["n"]
                node_dict = dict(node)
                logger.debug(
                    f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}"
                )
                return node_dict
            return None

    async def node_degree(self, node_id: str) -> int:
        """Return the number of relationships touching the labelled node.

        Returns 0 for a missing node, honouring the ``-> int`` annotation
        (the original returned None in that branch).
        """
        entity_name_label = node_id.strip('"')

        async with self._driver.session() as session:
            query = f"""
                MATCH (n:`{entity_name_label}`)
                RETURN COUNT{{ (n)--() }} AS totalEdgeCount
            """
            result = await session.run(query)
            record = await result.single()
            if record:
                edge_count = record["totalEdgeCount"]
                logger.debug(
                    f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}"
                )
                return edge_count
            return 0

    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
        """Return the sum of both endpoint node degrees."""
        entity_name_label_source = src_id.strip('"')
        entity_name_label_target = tgt_id.strip('"')
        src_degree = await self.node_degree(entity_name_label_source)
        trg_degree = await self.node_degree(entity_name_label_target)

        # Defensive: treat None as 0 (node_degree itself now returns 0).
        src_degree = 0 if src_degree is None else src_degree
        trg_degree = 0 if trg_degree is None else trg_degree

        degrees = int(src_degree) + int(trg_degree)
        logger.debug(
            f"{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}"
        )
        return degrees

    async def get_edge(
        self, source_node_id: str, target_node_id: str
    ) -> Union[dict, None]:
        """Return the property dict of one edge between two labels, or None.

        Args:
            source_node_id (str): Label of the source nodes.
            target_node_id (str): Label of the target nodes.

        Returns:
            dict of edge properties for the first matching relationship, or
            None when no edge exists.
        """
        entity_name_label_source = source_node_id.strip('"')
        entity_name_label_target = target_node_id.strip('"')
        async with self._driver.session() as session:
            # f-string interpolation only; the original additionally chained a
            # redundant .format(...) onto the already-interpolated string.
            query = f"""
                MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
                RETURN properties(r) as edge_properties
                LIMIT 1
            """

            result = await session.run(query)
            record = await result.single()
            if record:
                edge_properties = dict(record["edge_properties"])
                logger.debug(
                    f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_properties}"
                )
                return edge_properties
            return None

    async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
        """Return (source_label, target_label) pairs for every edge of a node.

        Retrieves all relationships of the node identified by its label; pairs
        where either endpoint has no label are skipped.
        """
        node_label = source_node_id.strip('"')
        query = f"""MATCH (n:`{node_label}`)
                OPTIONAL MATCH (n)-[r]-(connected)
                RETURN n, r, connected"""
        async with self._driver.session() as session:
            results = await session.run(query)
            edges = []
            async for record in results:
                source_node = record["n"]
                connected_node = record["connected"]

                # First label of each endpoint identifies the entity.
                source_label = (
                    list(source_node.labels)[0] if source_node.labels else None
                )
                target_label = (
                    list(connected_node.labels)[0]
                    if connected_node and connected_node.labels
                    else None
                )

                if source_label and target_label:
                    edges.append((source_label, target_label))

            return edges

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (
                neo4jExceptions.ServiceUnavailable,
                neo4jExceptions.TransientError,
                neo4jExceptions.WriteServiceUnavailable,
            )
        ),
    )
    async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
        """Upsert a node in the Neo4j database.

        Args:
            node_id: The unique identifier for the node (used as label).
            node_data: Dictionary of node properties.

        Raises:
            Exception: re-raised after logging if the write fails (retried up
            to 3 times on transient driver errors).
        """
        label = node_id.strip('"')
        properties = node_data

        async def _do_upsert(tx: AsyncManagedTransaction):
            # MERGE keys on the label alone; properties are overlaid with +=.
            query = f"""
            MERGE (n:`{label}`)
            SET n += $properties
            """
            await tx.run(query, properties=properties)
            logger.debug(
                f"Upserted node with label '{label}' and properties: {properties}"
            )

        try:
            async with self._driver.session() as session:
                await session.execute_write(_do_upsert)
        except Exception as e:
            logger.error(f"Error during upsert: {str(e)}")
            raise

    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
        retry=retry_if_exception_type(
            (
                neo4jExceptions.ServiceUnavailable,
                neo4jExceptions.TransientError,
                neo4jExceptions.WriteServiceUnavailable,
            )
        ),
    )
    async def upsert_edge(
        self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
    ):
        """Upsert an edge and its properties between two labelled nodes.

        Args:
            source_node_id (str): Label of the source node (used as identifier).
            target_node_id (str): Label of the target node (used as identifier).
            edge_data (dict): Dictionary of properties to set on the edge.

        Raises:
            Exception: re-raised after logging if the write fails (retried up
            to 3 times on transient driver errors).
        """
        source_node_label = source_node_id.strip('"')
        target_node_label = target_node_id.strip('"')
        edge_properties = edge_data

        async def _do_upsert_edge(tx: AsyncManagedTransaction):
            # Both endpoints must already exist; MERGE only creates the edge.
            query = f"""
            MATCH (source:`{source_node_label}`)
            WITH source
            MATCH (target:`{target_node_label}`)
            MERGE (source)-[r:DIRECTED]->(target)
            SET r += $properties
            RETURN r
            """
            await tx.run(query, properties=edge_properties)
            logger.debug(
                f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}"
            )

        try:
            async with self._driver.session() as session:
                await session.execute_write(_do_upsert_edge)
        except Exception as e:
            logger.error(f"Error during edge upsert: {str(e)}")
            raise

    async def _node2vec_embed(self):
        print("Implemented but never called.")
lightrag/lightrag.py CHANGED
@@ -1,5 +1,6 @@
1
  import asyncio
2
  import os
 
3
  from dataclasses import asdict, dataclass, field
4
  from datetime import datetime
5
  from functools import partial
@@ -23,6 +24,18 @@ from .storage import (
23
  NanoVectorDBStorage,
24
  NetworkXStorage,
25
  )
 
 
 
 
 
 
 
 
 
 
 
 
26
  from .utils import (
27
  EmbeddingFunc,
28
  compute_mdhash_id,
@@ -44,18 +57,27 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
44
  try:
45
  loop = asyncio.get_running_loop()
46
  except RuntimeError:
47
- logger.info("Creating a new event loop in a sub-thread.")
48
- loop = asyncio.new_event_loop()
49
- asyncio.set_event_loop(loop)
 
50
  return loop
51
 
52
 
53
  @dataclass
54
  class LightRAG:
 
55
  working_dir: str = field(
56
  default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
57
  )
58
 
 
 
 
 
 
 
 
59
  # text chunking
60
  chunk_token_size: int = 1200
61
  chunk_overlap_token_size: int = 100
@@ -94,7 +116,6 @@ class LightRAG:
94
  key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
95
  vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
96
  vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
97
- graph_storage_cls: Type[BaseGraphStorage] = NetworkXStorage
98
  enable_llm_cache: bool = True
99
 
100
  # extension
@@ -104,11 +125,16 @@ class LightRAG:
104
  def __post_init__(self):
105
  log_file = os.path.join(self.working_dir, "lightrag.log")
106
  set_logger(log_file)
 
 
107
  logger.info(f"Logger initialized for working directory: {self.working_dir}")
108
 
109
  _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
110
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
111
 
 
 
 
112
  if not os.path.exists(self.working_dir):
113
  logger.info(f"Creating working directory {self.working_dir}")
114
  os.makedirs(self.working_dir)
@@ -161,6 +187,12 @@ class LightRAG:
161
  **self.llm_model_kwargs,
162
  )
163
  )
 
 
 
 
 
 
164
 
165
  def insert(self, string_or_strings):
166
  loop = always_get_an_event_loop()
@@ -298,4 +330,4 @@ class LightRAG:
298
  if storage_inst is None:
299
  continue
300
  tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
301
- await asyncio.gather(*tasks)
 
1
  import asyncio
2
  import os
3
+ import importlib
4
  from dataclasses import asdict, dataclass, field
5
  from datetime import datetime
6
  from functools import partial
 
24
  NanoVectorDBStorage,
25
  NetworkXStorage,
26
  )
27
+
28
+ from .kg.neo4j_impl import (
29
+ Neo4JStorage
30
+ )
31
+ #future KG integrations
32
+
33
+ # from .kg.ArangoDB_impl import (
34
+ # GraphStorage as ArangoDBStorage
35
+ # )
36
+
37
+
38
+
39
  from .utils import (
40
  EmbeddingFunc,
41
  compute_mdhash_id,
 
57
  try:
58
  loop = asyncio.get_running_loop()
59
  except RuntimeError:
60
+ logger.info("Creating a new event loop in main thread.")
61
+ # loop = asyncio.new_event_loop()
62
+ # asyncio.set_event_loop(loop)
63
+ loop = asyncio.get_event_loop()
64
  return loop
65
 
66
 
67
  @dataclass
68
  class LightRAG:
69
+
70
  working_dir: str = field(
71
  default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
72
  )
73
 
74
+ kg: str = field(default="NetworkXStorage")
75
+
76
+ current_log_level = logger.level
77
+ log_level: str = field(default=current_log_level)
78
+
79
+
80
+
81
  # text chunking
82
  chunk_token_size: int = 1200
83
  chunk_overlap_token_size: int = 100
 
116
  key_string_value_json_storage_cls: Type[BaseKVStorage] = JsonKVStorage
117
  vector_db_storage_cls: Type[BaseVectorStorage] = NanoVectorDBStorage
118
  vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
 
119
  enable_llm_cache: bool = True
120
 
121
  # extension
 
125
  def __post_init__(self):
126
  log_file = os.path.join(self.working_dir, "lightrag.log")
127
  set_logger(log_file)
128
+ logger.setLevel(self.log_level)
129
+
130
  logger.info(f"Logger initialized for working directory: {self.working_dir}")
131
 
132
  _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
133
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
134
 
135
+ #@TODO: should move all storage setup here to leverage initial start params attached to self.
136
+ self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[self.kg]
137
+
138
  if not os.path.exists(self.working_dir):
139
  logger.info(f"Creating working directory {self.working_dir}")
140
  os.makedirs(self.working_dir)
 
187
  **self.llm_model_kwargs,
188
  )
189
  )
190
+ def _get_storage_class(self) -> Type[BaseGraphStorage]:
191
+ return {
192
+ "Neo4JStorage": Neo4JStorage,
193
+ "NetworkXStorage": NetworkXStorage,
194
+ # "ArangoDBStorage": ArangoDBStorage
195
+ }
196
 
197
  def insert(self, string_or_strings):
198
  loop = always_get_an_event_loop()
 
330
  if storage_inst is None:
331
  continue
332
  tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
333
+ await asyncio.gather(*tasks)
lightrag/llm.py CHANGED
@@ -798,4 +798,4 @@ if __name__ == "__main__":
798
  result = await gpt_4o_mini_complete("How are you?")
799
  print(result)
800
 
801
- asyncio.run(main())
 
798
  result = await gpt_4o_mini_complete("How are you?")
799
  print(result)
800
 
801
+ asyncio.run(main())
lightrag/operate.py CHANGED
@@ -466,7 +466,9 @@ async def _build_local_query_context(
466
  text_chunks_db: BaseKVStorage[TextChunkSchema],
467
  query_param: QueryParam,
468
  ):
 
469
  results = await entities_vdb.query(query, top_k=query_param.top_k)
 
470
  if not len(results):
471
  return None
472
  node_datas = await asyncio.gather(
@@ -481,7 +483,7 @@ async def _build_local_query_context(
481
  {**n, "entity_name": k["entity_name"], "rank": d}
482
  for k, n, d in zip(results, node_datas, node_degrees)
483
  if n is not None
484
- ]
485
  use_text_units = await _find_most_related_text_unit_from_entities(
486
  node_datas, query_param, text_chunks_db, knowledge_graph_inst
487
  )
@@ -907,7 +909,6 @@ async def hybrid_query(
907
  .strip()
908
  )
909
  result = "{" + result.split("{")[1].split("}")[0] + "}"
910
-
911
  keywords_data = json.loads(result)
912
  hl_keywords = keywords_data.get("high_level_keywords", [])
913
  ll_keywords = keywords_data.get("low_level_keywords", [])
@@ -927,6 +928,7 @@ async def hybrid_query(
927
  query_param,
928
  )
929
 
 
930
  if hl_keywords:
931
  high_level_context = await _build_global_query_context(
932
  hl_keywords,
@@ -937,6 +939,7 @@ async def hybrid_query(
937
  query_param,
938
  )
939
 
 
940
  context = combine_contexts(high_level_context, low_level_context)
941
 
942
  if query_param.only_need_context:
@@ -1043,6 +1046,7 @@ async def naive_query(
1043
  chunks_ids = [r["id"] for r in results]
1044
  chunks = await text_chunks_db.get_by_ids(chunks_ids)
1045
 
 
1046
  maybe_trun_chunks = truncate_list_by_token_size(
1047
  chunks,
1048
  key=lambda x: x["content"],
@@ -1073,4 +1077,4 @@ async def naive_query(
1073
  .strip()
1074
  )
1075
 
1076
- return response
 
466
  text_chunks_db: BaseKVStorage[TextChunkSchema],
467
  query_param: QueryParam,
468
  ):
469
+
470
  results = await entities_vdb.query(query, top_k=query_param.top_k)
471
+
472
  if not len(results):
473
  return None
474
  node_datas = await asyncio.gather(
 
483
  {**n, "entity_name": k["entity_name"], "rank": d}
484
  for k, n, d in zip(results, node_datas, node_degrees)
485
  if n is not None
486
+ ]#what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
487
  use_text_units = await _find_most_related_text_unit_from_entities(
488
  node_datas, query_param, text_chunks_db, knowledge_graph_inst
489
  )
 
909
  .strip()
910
  )
911
  result = "{" + result.split("{")[1].split("}")[0] + "}"
 
912
  keywords_data = json.loads(result)
913
  hl_keywords = keywords_data.get("high_level_keywords", [])
914
  ll_keywords = keywords_data.get("low_level_keywords", [])
 
928
  query_param,
929
  )
930
 
931
+
932
  if hl_keywords:
933
  high_level_context = await _build_global_query_context(
934
  hl_keywords,
 
939
  query_param,
940
  )
941
 
942
+
943
  context = combine_contexts(high_level_context, low_level_context)
944
 
945
  if query_param.only_need_context:
 
1046
  chunks_ids = [r["id"] for r in results]
1047
  chunks = await text_chunks_db.get_by_ids(chunks_ids)
1048
 
1049
+
1050
  maybe_trun_chunks = truncate_list_by_token_size(
1051
  chunks,
1052
  key=lambda x: x["content"],
 
1077
  .strip()
1078
  )
1079
 
1080
+ return response
lightrag/storage.py CHANGED
@@ -233,6 +233,8 @@ class NetworkXStorage(BaseGraphStorage):
233
  raise ValueError(f"Node embedding algorithm {algorithm} not supported")
234
  return await self._node_embed_algorithms[algorithm]()
235
 
 
 
236
  async def _node2vec_embed(self):
237
  from graspologic import embed
238
 
 
233
  raise ValueError(f"Node embedding algorithm {algorithm} not supported")
234
  return await self._node_embed_algorithms[algorithm]()
235
 
236
+
237
+ #@TODO: NOT USED
238
  async def _node2vec_embed(self):
239
  from graspologic import embed
240
 
requirements.txt CHANGED
@@ -4,6 +4,7 @@ aiohttp
4
  graspologic
5
  hnswlib
6
  nano-vectordb
 
7
  networkx
8
  ollama
9
  openai
 
4
  graspologic
5
  hnswlib
6
  nano-vectordb
7
+ neo4j
8
  networkx
9
  ollama
10
  openai
test.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Smoke-test script: index ./book.txt with the default (NetworkX) storage
and run one question through each retrieval mode."""
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
from pprint import pprint
#########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
# import nest_asyncio
# nest_asyncio.apply()
#########

WORKING_DIR = "./dickens"


def main():
    """Build the index and print results for every query mode."""
    # exist_ok avoids the check-then-create race of exists()+mkdir().
    os.makedirs(WORKING_DIR, exist_ok=True)

    rag = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=gpt_4o_mini_complete  # Use gpt_4o_mini_complete LLM model
        # llm_model_func=gpt_4o_complete  # Optionally, use a stronger model
    )

    # Explicit encoding: default is platform-dependent.
    with open("./book.txt", encoding="utf-8") as f:
        rag.insert(f.read())

    question = "What are the top themes in this story?"
    # Perform naive, local, global and hybrid search on the same question.
    for mode in ("naive", "local", "global", "hybrid"):
        print(rag.query(question, param=QueryParam(mode=mode)))


if __name__ == "__main__":
    # Guarded entry point: importing this module no longer runs the pipeline.
    main()
test_neo4j.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Smoke-test script: index ./book.txt with the Neo4JStorage backend and run
one question through each retrieval mode.

Requires NEO4J_URI / NEO4J_USERNAME / NEO4J_PASSWORD in the environment."""
import os
from lightrag import LightRAG, QueryParam
from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete


#########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
# import nest_asyncio
# nest_asyncio.apply()
#########

WORKING_DIR = "./local_neo4jWorkDir"


def main():
    """Build the Neo4j-backed index and print results for every query mode."""
    # exist_ok avoids the check-then-create race of exists()+mkdir().
    os.makedirs(WORKING_DIR, exist_ok=True)

    rag = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=gpt_4o_mini_complete,  # Use gpt_4o_mini_complete LLM model
        kg="Neo4JStorage",  # override the NetworkXStorage default
        log_level="INFO"
        # llm_model_func=gpt_4o_complete  # Optionally, use a stronger model
    )

    # Explicit encoding: default is platform-dependent.
    with open("./book.txt", encoding="utf-8") as f:
        rag.insert(f.read())

    question = "What are the top themes in this story?"
    # Perform naive, local, global and hybrid search on the same question.
    for mode in ("naive", "local", "global", "hybrid"):
        print(rag.query(question, param=QueryParam(mode=mode)))


if __name__ == "__main__":
    # Guarded entry point: importing this module no longer runs the pipeline.
    main()
+ print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))