Ken Wiltshire committed on
Commit 50c91b7 · 1 Parent(s): bae5305

fix event loop conflict

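The conflict in the commit title comes from synchronous entry points calling asyncio while a loop may or may not already exist. lightrag.py keeps a helper named always_get_an_event_loop for this; the snippet below is a minimal, self-contained sketch of that general pattern (the body is illustrative, not copied from the repository):

import asyncio

def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
    # Reuse the current loop when one exists; otherwise create and register a new one.
    try:
        return asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        return loop

async def demo() -> str:
    return "ok"

if __name__ == "__main__":
    loop = always_get_an_event_loop()
    print(loop.run_until_complete(demo()))
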
get_all_edges_nx.py CHANGED
@@ -1,28 +1,34 @@
 import networkx as nx
 
-G = nx.read_graphml('./dickensTestEmbedcall/graph_chunk_entity_relation.graphml')
+G = nx.read_graphml("./dickensTestEmbedcall/graph_chunk_entity_relation.graphml")
+
 
 def get_all_edges_and_nodes(G):
     # Get all edges and their properties
     edges_with_properties = []
     for u, v, data in G.edges(data=True):
-        edges_with_properties.append({
-            'start': u,
-            'end': v,
-            'label': data.get('label', ''),  # Assuming 'label' is used for edge type
-            'properties': data,
-            'start_node_properties': G.nodes[u],
-            'end_node_properties': G.nodes[v]
-        })
+        edges_with_properties.append(
+            {
+                "start": u,
+                "end": v,
+                "label": data.get(
+                    "label", ""
+                ),  # Assuming 'label' is used for edge type
+                "properties": data,
+                "start_node_properties": G.nodes[u],
+                "end_node_properties": G.nodes[v],
+            }
+        )
 
     return edges_with_properties
 
+
 # Example usage
 if __name__ == "__main__":
     # Assume G is your NetworkX graph loaded from Neo4j
 
     all_edges = get_all_edges_and_nodes(G)
-
+
     # Print all edges and node properties
     for edge in all_edges:
         print(f"Edge Label: {edge['label']}")
@@ -31,4 +37,4 @@ if __name__ == "__main__":
         print(f"Start Node Properties: {edge['start_node_properties']}")
         print(f"End Node: {edge['end']}")
         print(f"End Node Properties: {edge['end_node_properties']}")
-        print("---")
+        print("---")
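For reference, the helper above works against any in-memory graph, not only the GraphML dump; a short self-contained example (node and edge names are invented for illustration):

import networkx as nx

# Build a tiny graph in memory instead of loading the GraphML file.
G = nx.Graph()
G.add_node("Scrooge", entity_type="person")
G.add_node("Marley", entity_type="person")
G.add_edge("Scrooge", "Marley", label="BUSINESS_PARTNER", weight=1.0)

for u, v, data in G.edges(data=True):
    print(u, "->", v, data.get("label", ""), dict(data))
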
lightrag/kg/neo4j_impl.py CHANGED
@@ -1,17 +1,16 @@
 import asyncio
-import html
 import os
 from dataclasses import dataclass
-from typing import Any, Union, cast, Tuple, List, Dict
-import numpy as np
+from typing import Any, Union, Tuple, List, Dict
 import inspect
-from lightrag.utils import load_json, logger, write_json
-from ..base import (
-    BaseGraphStorage
+from lightrag.utils import logger
+from ..base import BaseGraphStorage
+from neo4j import (
+    AsyncGraphDatabase,
+    exceptions as neo4jExceptions,
+    AsyncDriver,
+    AsyncManagedTransaction,
 )
-from neo4j import AsyncGraphDatabase,exceptions as neo4jExceptions,AsyncDriver,AsyncSession, AsyncManagedTransaction
-
-from contextlib import asynccontextmanager
 
 
 from tenacity import (
@@ -26,7 +25,7 @@ from tenacity import (
 class Neo4JStorage(BaseGraphStorage):
     @staticmethod
     def load_nx_graph(file_name):
-        print ("no preloading of graph with neo4j in production")
+        print("no preloading of graph with neo4j in production")
 
     def __init__(self, namespace, global_config):
         super().__init__(namespace=namespace, global_config=global_config)
@@ -35,7 +34,9 @@ class Neo4JStorage(BaseGraphStorage):
         URI = os.environ["NEO4J_URI"]
         USERNAME = os.environ["NEO4J_USERNAME"]
         PASSWORD = os.environ["NEO4J_PASSWORD"]
-        self._driver: AsyncDriver = AsyncGraphDatabase.driver(URI, auth=(USERNAME, PASSWORD))
+        self._driver: AsyncDriver = AsyncGraphDatabase.driver(
+            URI, auth=(USERNAME, PASSWORD)
+        )
         return None
 
     def __post_init__(self):
@@ -43,59 +44,54 @@ class Neo4JStorage(BaseGraphStorage):
             "node2vec": self._node2vec_embed,
         }
 
-
     async def close(self):
         if self._driver:
            await self._driver.close()
            self._driver = None
 
-
-
    async def __aexit__(self, exc_type, exc, tb):
        if self._driver:
            await self._driver.close()
 
    async def index_done_callback(self):
-        print ("KG successfully indexed.")
+        print("KG successfully indexed.")
 
-
    async def has_node(self, node_id: str) -> bool:
-        entity_name_label = node_id.strip('\"')
+        entity_name_label = node_id.strip('"')
 
-        async with self._driver.session() as session:
-            query = f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
-            result = await session.run(query)
+        async with self._driver.session() as session:
+            query = (
+                f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
+            )
+            result = await session.run(query)
            single_result = await result.single()
            logger.debug(
-                f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
-            )
+                f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
+            )
            return single_result["node_exists"]
-
+
    async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
-        entity_name_label_source = source_node_id.strip('\"')
-        entity_name_label_target = target_node_id.strip('\"')
-
-        async with self._driver.session() as session:
-            query = (
-                f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
-                "RETURN COUNT(r) > 0 AS edgeExists"
-            )
-            result = await session.run(query)
+        entity_name_label_source = source_node_id.strip('"')
+        entity_name_label_target = target_node_id.strip('"')
+
+        async with self._driver.session() as session:
+            query = (
+                f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
+                "RETURN COUNT(r) > 0 AS edgeExists"
+            )
+            result = await session.run(query)
            single_result = await result.single()
            logger.debug(
-                f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
-            )
+                f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
+            )
            return single_result["edgeExists"]
-
-    def close(self):
-        self._driver.close()
-
-
 
+    def close(self):
+        self._driver.close()
 
    async def get_node(self, node_id: str) -> Union[dict, None]:
        async with self._driver.session() as session:
-            entity_name_label = node_id.strip('\"')
+            entity_name_label = node_id.strip('"')
            query = f"MATCH (n:`{entity_name_label}`) RETURN n"
            result = await session.run(query)
            record = await result.single()
@@ -103,54 +99,51 @@ class Neo4JStorage(BaseGraphStorage):
                node = record["n"]
                node_dict = dict(node)
                logger.debug(
-                    f'{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}'
+                    f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}"
                )
                return node_dict
            return None
-
-
 
    async def node_degree(self, node_id: str) -> int:
-        entity_name_label = node_id.strip('\"')
+        entity_name_label = node_id.strip('"')
 
-        async with self._driver.session() as session:
+        async with self._driver.session() as session:
            query = f"""
                MATCH (n:`{entity_name_label}`)
                RETURN COUNT{{ (n)--() }} AS totalEdgeCount
            """
-            result = await session.run(query)
-            record = await result.single()
+            result = await session.run(query)
+            record = await result.single()
            if record:
-                edge_count = record["totalEdgeCount"]
+                edge_count = record["totalEdgeCount"]
                logger.debug(
-                    f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}'
-                )
+                    f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}"
+                )
                return edge_count
-            else:
+            else:
                return None
-
 
    async def edge_degree(self, src_id: str, tgt_id: str) -> int:
-        entity_name_label_source = src_id.strip('\"')
-        entity_name_label_target = tgt_id.strip('\"')
+        entity_name_label_source = src_id.strip('"')
+        entity_name_label_target = tgt_id.strip('"')
        src_degree = await self.node_degree(entity_name_label_source)
        trg_degree = await self.node_degree(entity_name_label_target)
-
+
        # Convert None to 0 for addition
        src_degree = 0 if src_degree is None else src_degree
        trg_degree = 0 if trg_degree is None else trg_degree
 
        degrees = int(src_degree) + int(trg_degree)
        logger.debug(
-            f'{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}'
-        )
+            f"{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}"
+        )
        return degrees
 
-
-
-    async def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
-        entity_name_label_source = source_node_id.strip('\"')
-        entity_name_label_target = target_node_id.strip('\"')
+    async def get_edge(
+        self, source_node_id: str, target_node_id: str
+    ) -> Union[dict, None]:
+        entity_name_label_source = source_node_id.strip('"')
+        entity_name_label_target = target_node_id.strip('"')
        """
        Find all edges between nodes of two given labels
 
@@ -161,28 +154,30 @@ class Neo4JStorage(BaseGraphStorage):
        Returns:
            list: List of all relationships/edges found
        """
-        async with self._driver.session() as session:
+        async with self._driver.session() as session:
            query = f"""
            MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
            RETURN properties(r) as edge_properties
            LIMIT 1
-            """.format(entity_name_label_source=entity_name_label_source, entity_name_label_target=entity_name_label_target)
-
-            result = await session.run(query)
+            """.format(
+                entity_name_label_source=entity_name_label_source,
+                entity_name_label_target=entity_name_label_target,
+            )
+
+            result = await session.run(query)
            record = await result.single()
            if record:
                result = dict(record["edge_properties"])
                logger.debug(
-                    f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}'
-                )
+                    f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
+                )
                return result
            else:
                return None
-
 
-    async def get_node_edges(self, source_node_id: str)-> List[Tuple[str, str]]:
-        node_label = source_node_id.strip('\"')
-
+    async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
+        node_label = source_node_id.strip('"')
+
        """
        Retrieves all edges (relationships) for a particular node identified by its label.
        :return: List of dictionaries containing edge information
@@ -190,26 +185,37 @@ class Neo4JStorage(BaseGraphStorage):
        query = f"""MATCH (n:`{node_label}`)
                OPTIONAL MATCH (n)-[r]-(connected)
                RETURN n, r, connected"""
-        async with self._driver.session() as session:
+        async with self._driver.session() as session:
            results = await session.run(query)
            edges = []
            async for record in results:
-                source_node = record['n']
-                connected_node = record['connected']
-
-                source_label = list(source_node.labels)[0] if source_node.labels else None
-                target_label = list(connected_node.labels)[0] if connected_node and connected_node.labels else None
-
+                source_node = record["n"]
+                connected_node = record["connected"]
+
+                source_label = (
+                    list(source_node.labels)[0] if source_node.labels else None
+                )
+                target_label = (
+                    list(connected_node.labels)[0]
+                    if connected_node and connected_node.labels
+                    else None
+                )
+
                if source_label and target_label:
                    edges.append((source_label, target_label))
-
-            return edges
 
+            return edges
 
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
-        retry=retry_if_exception_type((neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable)),
+        retry=retry_if_exception_type(
+            (
+                neo4jExceptions.ServiceUnavailable,
+                neo4jExceptions.TransientError,
+                neo4jExceptions.WriteServiceUnavailable,
+            )
+        ),
    )
    async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
        """
@@ -219,7 +225,7 @@ class Neo4JStorage(BaseGraphStorage):
            node_id: The unique identifier for the node (used as label)
            node_data: Dictionary of node properties
        """
-        label = node_id.strip('\"')
+        label = node_id.strip('"')
        properties = node_data
 
        async def _do_upsert(tx: AsyncManagedTransaction):
@@ -228,7 +234,9 @@ class Neo4JStorage(BaseGraphStorage):
            SET n += $properties
            """
            await tx.run(query, properties=properties)
-            logger.debug(f"Upserted node with label '{label}' and properties: {properties}")
+            logger.debug(
+                f"Upserted node with label '{label}' and properties: {properties}"
+            )
 
        try:
            async with self._driver.session() as session:
@@ -236,13 +244,21 @@ class Neo4JStorage(BaseGraphStorage):
        except Exception as e:
            logger.error(f"Error during upsert: {str(e)}")
            raise
-
+
    @retry(
        stop=stop_after_attempt(3),
        wait=wait_exponential(multiplier=1, min=4, max=10),
-        retry=retry_if_exception_type((neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable)),
+        retry=retry_if_exception_type(
+            (
+                neo4jExceptions.ServiceUnavailable,
+                neo4jExceptions.TransientError,
+                neo4jExceptions.WriteServiceUnavailable,
+            )
+        ),
    )
-    async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]):
+    async def upsert_edge(
+        self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
+    ):
        """
        Upsert an edge and its properties between two nodes identified by their labels.
 
@@ -251,8 +267,8 @@ class Neo4JStorage(BaseGraphStorage):
            target_node_id (str): Label of the target node (used as identifier)
            edge_data (dict): Dictionary of properties to set on the edge
        """
-        source_node_label = source_node_id.strip('\"')
-        target_node_label = target_node_id.strip('\"')
+        source_node_label = source_node_id.strip('"')
+        target_node_label = target_node_id.strip('"')
        edge_properties = edge_data
 
        async def _do_upsert_edge(tx: AsyncManagedTransaction):
@@ -265,7 +281,9 @@ class Neo4JStorage(BaseGraphStorage):
            RETURN r
            """
            await tx.run(query, properties=edge_properties)
-            logger.debug(f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}")
+            logger.debug(
+                f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}"
+            )
 
        try:
            async with self._driver.session() as session:
@@ -273,6 +291,6 @@ class Neo4JStorage(BaseGraphStorage):
        except Exception as e:
            logger.error(f"Error during edge upsert: {str(e)}")
            raise
+
    async def _node2vec_embed(self):
-        print ("Implemented but never called.")
-
+        print("Implemented but never called.")
lightrag/lightrag.py CHANGED
@@ -1,6 +1,5 @@
 import asyncio
 import os
-import importlib
 from dataclasses import asdict, dataclass, field
 from datetime import datetime
 from functools import partial
@@ -24,18 +23,15 @@ from .storage import (
     NanoVectorDBStorage,
     NetworkXStorage,
 )
-
-from .kg.neo4j_impl import (
-    Neo4JStorage
-)
-#future KG integrations
+
+from .kg.neo4j_impl import Neo4JStorage
+# future KG integrations
 
 # from .kg.ArangoDB_impl import (
 #     GraphStorage as ArangoDBStorage
 # )
 
 
-
 from .utils import (
     EmbeddingFunc,
     compute_mdhash_id,
@@ -52,6 +48,7 @@ from .base import (
     QueryParam,
 )
 
+
 def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
     try:
         loop = asyncio.get_event_loop()
@@ -64,7 +61,6 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
 
 @dataclass
 class LightRAG:
-
     working_dir: str = field(
         default_factory=lambda: f"./lightrag_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
     )
@@ -74,8 +70,6 @@ class LightRAG:
     current_log_level = logger.level
     log_level: str = field(default=current_log_level)
 
-
-
     # text chunking
     chunk_token_size: int = 1200
     chunk_overlap_token_size: int = 100
@@ -130,8 +124,10 @@ class LightRAG:
         _print_config = ",\n ".join([f"{k} = {v}" for k, v in asdict(self).items()])
         logger.debug(f"LightRAG init with param:\n {_print_config}\n")
 
-        #@TODO: should move all storage setup here to leverage initial start params attached to self.
-        self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[self.kg]
+        # @TODO: should move all storage setup here to leverage initial start params attached to self.
+        self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
+            self.kg
+        ]
 
         if not os.path.exists(self.working_dir):
             logger.info(f"Creating working directory {self.working_dir}")
@@ -185,6 +181,7 @@ class LightRAG:
                 **self.llm_model_kwargs,
             )
         )
+
     def _get_storage_class(self) -> Type[BaseGraphStorage]:
         return {
             "Neo4JStorage": Neo4JStorage,
@@ -328,4 +325,4 @@ class LightRAG:
             if storage_inst is None:
                 continue
             tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
-        await asyncio.gather(*tasks)
+        await asyncio.gather(*tasks)
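_get_storage_class is a plain name-to-class registry, and LightRAG.__init__ selects the graph backend with self._get_storage_class()[self.kg]. A reduced sketch of that lookup pattern, with placeholder classes standing in for the real storages:

from typing import Dict, Type

class BaseGraphStorage:  # placeholder for lightrag's base class
    pass

class Neo4JStorage(BaseGraphStorage):  # placeholder backend
    pass

class NetworkXStorage(BaseGraphStorage):  # placeholder backend
    pass

def get_storage_class() -> Dict[str, Type[BaseGraphStorage]]:
    # Same shape as LightRAG._get_storage_class: map the `kg` string to a class.
    return {
        "Neo4JStorage": Neo4JStorage,
        "NetworkXStorage": NetworkXStorage,
    }

kg = "Neo4JStorage"
graph_storage_cls = get_storage_class()[kg]
print(graph_storage_cls.__name__)
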
lightrag/llm.py CHANGED
@@ -798,4 +798,4 @@ if __name__ == "__main__":
         result = await gpt_4o_mini_complete("How are you?")
         print(result)
 
-    asyncio.run(main())
+    asyncio.run(main())
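This hunk only touches the final asyncio.run(main()) line, but the placement matters for the commit's theme: asyncio.run() raises RuntimeError when another loop is already running (for example in a notebook), so it belongs under the __main__ guard rather than at import time. A minimal illustration of that guarded entry point:

import asyncio

async def main():
    await asyncio.sleep(0)
    print("done")

if __name__ == "__main__":
    # Keep asyncio.run() out of module import so importing the file
    # from an already-running event loop does not conflict with it.
    asyncio.run(main())
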
lightrag/operate.py CHANGED
@@ -466,7 +466,6 @@ async def _build_local_query_context(
     text_chunks_db: BaseKVStorage[TextChunkSchema],
     query_param: QueryParam,
 ):
-
     results = await entities_vdb.query(query, top_k=query_param.top_k)
 
     if not len(results):
@@ -483,7 +482,7 @@
         {**n, "entity_name": k["entity_name"], "rank": d}
         for k, n, d in zip(results, node_datas, node_degrees)
         if n is not None
-    ]#what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
+    ]  # what is this text_chunks_db doing. dont remember it in airvx. check the diagram.
     use_text_units = await _find_most_related_text_unit_from_entities(
         node_datas, query_param, text_chunks_db, knowledge_graph_inst
     )
@@ -928,7 +927,6 @@ async def hybrid_query(
         query_param,
     )
 
-
     if hl_keywords:
         high_level_context = await _build_global_query_context(
             hl_keywords,
@@ -939,7 +937,6 @@
             query_param,
         )
 
-
     context = combine_contexts(high_level_context, low_level_context)
 
     if query_param.only_need_context:
@@ -1008,9 +1005,11 @@ def combine_contexts(high_level_context, low_level_context):
 
     # Combine and deduplicate the entities
     combined_entities = process_combine_contexts(hl_entities, ll_entities)
-
+
     # Combine and deduplicate the relationships
-    combined_relationships = process_combine_contexts(hl_relationships, ll_relationships)
+    combined_relationships = process_combine_contexts(
+        hl_relationships, ll_relationships
+    )
 
     # Combine and deduplicate the sources
     combined_sources = process_combine_contexts(hl_sources, ll_sources)
@@ -1046,7 +1045,6 @@
     chunks_ids = [r["id"] for r in results]
     chunks = await text_chunks_db.get_by_ids(chunks_ids)
 
-
     maybe_trun_chunks = truncate_list_by_token_size(
         chunks,
         key=lambda x: x["content"],
@@ -1077,4 +1075,4 @@
             .strip()
         )
 
-    return response
+    return response
lightrag/storage.py CHANGED
@@ -233,8 +233,7 @@ class NetworkXStorage(BaseGraphStorage):
             raise ValueError(f"Node embedding algorithm {algorithm} not supported")
         return await self._node_embed_algorithms[algorithm]()
 
-
-    #@TODO: NOT USED
+    # @TODO: NOT USED
     async def _node2vec_embed(self):
         from graspologic import embed
 
lightrag/utils.py CHANGED
@@ -9,7 +9,7 @@ import re
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
-from typing import Any, Union,List
+from typing import Any, Union, List
 import xml.etree.ElementTree as ET
 
 import numpy as np
@@ -176,19 +176,20 @@ def truncate_list_by_token_size(list_data: list, key: callable, max_token_size:
             return list_data[:i]
     return list_data
 
+
 def list_of_list_to_csv(data: List[List[str]]) -> str:
     output = io.StringIO()
     writer = csv.writer(output)
     writer.writerows(data)
     return output.getvalue()
+
+
 def csv_string_to_list(csv_string: str) -> List[List[str]]:
     output = io.StringIO(csv_string)
     reader = csv.reader(output)
     return [row for row in reader]
 
 
-
-
 def save_data_to_file(data, file_name):
     with open(file_name, "w", encoding="utf-8") as f:
         json.dump(data, f, ensure_ascii=False, indent=4)
@@ -253,13 +254,14 @@
         print(f"An error occurred: {e}")
         return None
 
+
 def process_combine_contexts(hl, ll):
     header = None
     list_hl = csv_string_to_list(hl.strip())
     list_ll = csv_string_to_list(ll.strip())
 
     if list_hl:
-        header=list_hl[0]
+        header = list_hl[0]
         list_hl = list_hl[1:]
     if list_ll:
         header = list_ll[0]
@@ -268,19 +270,17 @@ def process_combine_contexts(hl, ll):
         return ""
 
     if list_hl:
-        list_hl = [','.join(item[1:]) for item in list_hl if item]
+        list_hl = [",".join(item[1:]) for item in list_hl if item]
     if list_ll:
-        list_ll = [','.join(item[1:]) for item in list_ll if item]
+        list_ll = [",".join(item[1:]) for item in list_ll if item]
 
-    combined_sources_set = set(
-        filter(None, list_hl + list_ll)
-    )
+    combined_sources_set = set(filter(None, list_hl + list_ll))
 
     combined_sources = [",\t".join(header)]
 
     for i, item in enumerate(combined_sources_set, start=1):
         combined_sources.append(f"{i},\t{item}")
-
+
     combined_sources = "\n".join(combined_sources)
 
     return combined_sources
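To see what process_combine_contexts produces, here is a small run with two toy CSV strings (the data is invented; only the function itself comes from the module above, and the example assumes lightrag is importable):

from lightrag.utils import process_combine_contexts

hl = "id,entity,description\n0,SCROOGE,a miser\n1,MARLEY,a ghost"
ll = "id,entity,description\n0,SCROOGE,a miser\n1,TINY TIM,a child"

combined = process_combine_contexts(hl, ll)
print(combined)
# Prints the header joined with ",\t" followed by renumbered, de-duplicated rows,
# e.g. "SCROOGE,a miser", "MARLEY,a ghost", "TINY TIM,a child" (set order may vary).
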
test.py CHANGED
@@ -1,11 +1,10 @@
 import os
 from lightrag import LightRAG, QueryParam
 from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
-from pprint import pprint
 #########
 # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
-# import nest_asyncio
-# nest_asyncio.apply()
+# import nest_asyncio
+# nest_asyncio.apply()
 #########
 
 WORKING_DIR = "./dickens"
@@ -15,7 +14,7 @@ if not os.path.exists(WORKING_DIR):
 
 rag = LightRAG(
     working_dir=WORKING_DIR,
-    llm_model_func=gpt_4o_mini_complete  # Use gpt_4o_mini_complete LLM model
+    llm_model_func=gpt_4o_mini_complete,  # Use gpt_4o_mini_complete LLM model
     # llm_model_func=gpt_4o_complete  # Optionally, use a stronger model
 )
 
@@ -23,13 +22,21 @@ with open("./book.txt") as f:
     rag.insert(f.read())
 
 # Perform naive search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)
 
 # Perform local search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)
 
 # Perform global search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)
 
 # Perform hybrid search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
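The commented-out nest_asyncio lines in this script are the escape hatch for notebooks, where Jupyter already runs an event loop and the async machinery behind rag.insert()/rag.query() would otherwise collide with it (the same class of conflict this commit targets). Uncommented, the pattern looks like this (requires pip install nest_asyncio):

import nest_asyncio

nest_asyncio.apply()  # patches the running loop so nested blocking calls are allowed

# ...then construct LightRAG and call rag.insert()/rag.query() as in the script above.
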
test_neo4j.py CHANGED
@@ -5,8 +5,8 @@ from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
 
 #########
 # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
-# import nest_asyncio
-# nest_asyncio.apply()
+# import nest_asyncio
+# nest_asyncio.apply()
 #########
 
 WORKING_DIR = "./local_neo4jWorkDir"
@@ -18,7 +18,7 @@ rag = LightRAG(
     working_dir=WORKING_DIR,
     llm_model_func=gpt_4o_mini_complete,  # Use gpt_4o_mini_complete LLM model
     kg="Neo4JStorage",
-    log_level="INFO"
+    log_level="INFO",
     # llm_model_func=gpt_4o_complete  # Optionally, use a stronger model
 )
 
@@ -26,13 +26,21 @@ with open("./book.txt") as f:
     rag.insert(f.read())
 
 # Perform naive search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="naive")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)
 
 # Perform local search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="local")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)
 
 # Perform global search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="global")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)
 
 # Perform hybrid search
-print(rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid")))
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)