Ken Wiltshire commited on
Commit
ccf6919
·
1 Parent(s): f632fdf

using neo4j async

Browse files
lightrag/kg/neo4j_impl.py CHANGED
@@ -2,14 +2,16 @@ import asyncio
2
  import html
3
  import os
4
  from dataclasses import dataclass
5
- from typing import Any, Union, cast
6
  import numpy as np
7
  import inspect
8
  from lightrag.utils import load_json, logger, write_json
9
  from ..base import (
10
  BaseGraphStorage
11
  )
12
- from neo4j import GraphDatabase, exceptions as neo4jExceptions
 
 
13
 
14
 
15
  from tenacity import (
@@ -20,126 +22,135 @@ from tenacity import (
20
  )
21
 
22
 
23
-
24
  @dataclass
25
- class GraphStorage(BaseGraphStorage):
26
  @staticmethod
27
  def load_nx_graph(file_name):
28
  print ("no preloading of graph with neo4j in production")
29
 
 
 
 
 
 
 
 
 
 
 
30
  def __post_init__(self):
31
  # self._graph = preloaded_graph or nx.Graph()
 
32
  credetial_parts = ['URI', 'USERNAME','PASSWORD']
33
  credentials_set = all(x in os.environ for x in credetial_parts )
34
- if credentials_set:
35
- URI = os.environ["NEO4J_URI"]
36
- USERNAME = os.environ["NEO4J_USERNAME"]
37
- PASSWORD = os.environ["NEO4J_PASSWORD"]
38
- else:
39
- raise Exception (f"One or more Neo4J Credentials, {credetial_parts}, not found in the environment")
40
-
41
- self._driver = GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD))
42
  self._node_embed_algorithms = {
43
  "node2vec": self._node2vec_embed,
44
  }
45
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  async def index_done_callback(self):
47
  print ("KG successfully indexed.")
 
 
48
  async def has_node(self, node_id: str) -> bool:
49
  entity_name_label = node_id.strip('\"')
50
 
51
- def _check_node_exists(tx, label):
52
- query = f"MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists"
53
- result = tx.run(query)
54
- single_result = result.single()
55
  logger.debug(
56
  f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
57
  )
58
-
59
  return single_result["node_exists"]
60
-
61
- with self._driver.session() as session:
62
- return session.read_transaction(_check_node_exists, entity_name_label)
63
-
64
-
65
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
66
  entity_name_label_source = source_node_id.strip('\"')
67
  entity_name_label_target = target_node_id.strip('\"')
68
 
69
-
70
- def _check_edge_existence(tx, label1, label2):
71
  query = (
72
- f"MATCH (a:`{label1}`)-[r]-(b:`{label2}`) "
73
  "RETURN COUNT(r) > 0 AS edgeExists"
74
  )
75
- result = tx.run(query)
76
- single_result = result.single()
77
  logger.debug(
78
  f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
79
  )
80
-
81
  return single_result["edgeExists"]
 
82
  def close(self):
83
  self._driver.close()
84
- #hard code relaitionship type, directed.
85
- with self._driver.session() as session:
86
- result = session.read_transaction(_check_edge_existence, entity_name_label_source, entity_name_label_target)
87
- return result
88
 
89
 
90
 
91
  async def get_node(self, node_id: str) -> Union[dict, None]:
92
- entity_name_label = node_id.strip('\"')
93
- with self._driver.session() as session:
94
- query = "MATCH (n:`{entity_name_label}`) RETURN n".format(entity_name_label=entity_name_label)
95
- result = session.run(query)
96
- for record in result:
97
- result = record["n"]
 
 
98
  logger.debug(
99
- f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}'
100
- )
101
- return result
102
-
 
103
 
104
 
105
  async def node_degree(self, node_id: str) -> int:
106
  entity_name_label = node_id.strip('\"')
107
 
108
-
109
- def _find_node_degree(session, label):
110
- with session.begin_transaction() as tx:
111
- query = f"""
112
- MATCH (n:`{label}`)
113
- RETURN COUNT{{ (n)--() }} AS totalEdgeCount
114
- """
115
- result = tx.run(query)
116
- record = result.single()
117
- if record:
118
- edge_count = record["totalEdgeCount"]
119
- logger.debug(
120
- f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}'
121
- )
122
- return edge_count
123
- else:
124
- return None
125
 
126
- with self._driver.session() as session:
127
- degree = _find_node_degree(session, entity_name_label)
128
- return degree
129
-
130
 
131
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
132
  entity_name_label_source = src_id.strip('\"')
133
  entity_name_label_target = tgt_id.strip('\"')
134
- with self._driver.session() as session:
135
- query = f"""MATCH (n1:`{entity_name_label_source}`)-[r]-(n2:`{entity_name_label_target}`)
136
- RETURN count(r) AS degree"""
137
- result = session.run(query)
138
- record = result.single()
139
- logger.debug(
140
- f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{record["degree"]}'
141
- )
142
- return record["degree"]
 
 
 
 
 
143
 
144
  async def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
145
  entity_name_label_source = source_node_id.strip('\"')
@@ -154,15 +165,15 @@ class GraphStorage(BaseGraphStorage):
154
  Returns:
155
  list: List of all relationships/edges found
156
  """
157
- with self._driver.session() as session:
158
  query = f"""
159
  MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
160
  RETURN properties(r) as edge_properties
161
  LIMIT 1
162
  """.format(entity_name_label_source=entity_name_label_source, entity_name_label_target=entity_name_label_target)
163
 
164
- result = session.run(query)
165
- record = result.single()
166
  if record:
167
  result = dict(record["edge_properties"])
168
  logger.debug(
@@ -173,29 +184,20 @@ class GraphStorage(BaseGraphStorage):
173
  return None
174
 
175
 
176
- async def get_node_edges(self, source_node_id: str):
177
  node_label = source_node_id.strip('\"')
178
 
179
  """
180
- Retrieves all edges (relationships) for a particular node identified by its label and ID.
181
-
182
- :param uri: Neo4j database URI
183
- :param username: Neo4j username
184
- :param password: Neo4j password
185
- :param node_label: Label of the node
186
- :param node_id: ID property of the node
187
  :return: List of dictionaries containing edge information
188
  """
189
-
190
- def fetch_edges(tx, label):
191
- query = f"""MATCH (n:`{label}`)
192
  OPTIONAL MATCH (n)-[r]-(connected)
193
  RETURN n, r, connected"""
194
-
195
- results = tx.run(query)
196
-
197
  edges = []
198
- for record in results:
199
  source_node = record['n']
200
  connected_node = record['connected']
201
 
@@ -207,7 +209,7 @@ class GraphStorage(BaseGraphStorage):
207
 
208
  return edges
209
 
210
- with self._driver.session() as session:
211
  edges = session.read_transaction(fetch_edges,node_label)
212
  return edges
213
 
@@ -217,86 +219,51 @@ class GraphStorage(BaseGraphStorage):
217
  wait=wait_exponential(multiplier=1, min=4, max=10),
218
  retry=retry_if_exception_type((neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable)),
219
  )
220
- async def upsert_node(self, node_id: str, node_data: dict[str, str]):
221
- label = node_id.strip('\"')
222
- properties = node_data
223
  """
224
- Upsert a node with the given label and properties within a transaction.
 
225
  Args:
226
- label: The node label to search for and apply
227
- properties: Dictionary of node properties
228
-
229
- Returns:
230
- Dictionary containing the node's properties after upsert, or None if operation fails
231
  """
232
- def _do_upsert(tx, label: str, properties: dict[str, Any]):
233
-
234
- """
235
- Args:
236
- tx: Neo4j transaction object
237
- label: The node label to search for and apply
238
- properties: Dictionary of node properties
239
-
240
- Returns:
241
- Dictionary containing the node's properties after upsert, or None if operation fails
242
- """
243
 
 
244
  query = f"""
245
  MERGE (n:`{label}`)
246
  SET n += $properties
247
- RETURN n
248
  """
249
- # Execute the query with properties as parameters
250
- # with session.begin_transaction() as tx:
251
- result = tx.run(query, properties=properties)
252
- record = result.single()
253
- if record:
254
- logger.debug(
255
- f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{dict(record["n"])}'
256
- )
257
- return dict(record["n"])
258
- return None
259
-
260
-
261
- with self._driver.session() as session:
262
- with session.begin_transaction() as tx:
263
- try:
264
- result = _do_upsert(tx,label,properties)
265
- tx.commit()
266
- return result
267
- except Exception as e:
268
- raise # roll back
269
-
270
 
271
-
272
- async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]) -> None:
273
- source_node_label = source_node_id.strip('\"')
274
- target_node_label = target_node_id.strip('\"')
275
- edge_properties = edge_data
 
276
  """
277
  Upsert an edge and its properties between two nodes identified by their labels.
278
-
279
  Args:
280
- source_node_label (str): Label of the source node (used as identifier)
281
- target_node_label (str): Label of the target node (used as identifier)
282
- edge_properties (dict): Dictionary of properties to set on the edge
283
  """
284
-
285
-
286
-
287
- def _do_upsert_edge(tx, source_node_label: str, target_node_label: str, edge_properties: dict[str, Any]) -> None:
288
- """
289
- Static method to perform the edge upsert within a transaction.
290
-
291
- The query will:
292
- 1. Match the source and target nodes by their labels
293
- 2. Merge the DIRECTED relationship
294
- 3. Set all properties on the relationship, updating existing ones and adding new ones
295
- """
296
- # Convert edge properties to Cypher parameter string
297
- # props_string = ", ".join(f"r.{key} = ${key}" for key in edge_properties.keys())
298
 
299
- # """.format(props_string)
300
  query = f"""
301
  MATCH (source:`{source_node_label}`)
302
  WITH source
@@ -305,22 +272,15 @@ class GraphStorage(BaseGraphStorage):
305
  SET r += $properties
306
  RETURN r
307
  """
308
-
309
- result = tx.run(query, properties=edge_properties)
310
- logger.debug(
311
- f'{inspect.currentframe().f_code.co_name}:query:{query}:edge_properties:{edge_properties}'
312
- )
313
- return result.single()
314
-
315
- with self._driver.session() as session:
316
- session.execute_write(
317
- _do_upsert_edge,
318
- source_node_label,
319
- target_node_label,
320
- edge_properties
321
- )
322
- # return result
323
-
324
  async def _node2vec_embed(self):
325
  print ("Implemented but never called.")
326
 
 
2
  import html
3
  import os
4
  from dataclasses import dataclass
5
+ from typing import Any, Union, cast, Tuple, List, Dict
6
  import numpy as np
7
  import inspect
8
  from lightrag.utils import load_json, logger, write_json
9
  from ..base import (
10
  BaseGraphStorage
11
  )
12
+ from neo4j import AsyncGraphDatabase,exceptions as neo4jExceptions,AsyncDriver,AsyncSession, AsyncManagedTransaction
13
+
14
+ from contextlib import asynccontextmanager
15
 
16
 
17
  from tenacity import (
 
22
  )
23
 
24
 
 
25
  @dataclass
26
+ class Neo4JStorage(BaseGraphStorage):
27
  @staticmethod
28
  def load_nx_graph(file_name):
29
  print ("no preloading of graph with neo4j in production")
30
 
31
+ def __init__(self, namespace, global_config):
32
+ super().__init__(namespace=namespace, global_config=global_config)
33
+ self._driver = None
34
+ self._driver_lock = asyncio.Lock()
35
+ URI = os.environ["NEO4J_URI"]
36
+ USERNAME = os.environ["NEO4J_USERNAME"]
37
+ PASSWORD = os.environ["NEO4J_PASSWORD"]
38
+ self._driver: AsyncDriver = AsyncGraphDatabase.driver(URI, auth=(USERNAME, PASSWORD))
39
+ return None
40
+
41
  def __post_init__(self):
42
  # self._graph = preloaded_graph or nx.Graph()
43
+ print("is this ever run")
44
  credetial_parts = ['URI', 'USERNAME','PASSWORD']
45
  credentials_set = all(x in os.environ for x in credetial_parts )
 
 
 
 
 
 
 
 
46
  self._node_embed_algorithms = {
47
  "node2vec": self._node2vec_embed,
48
  }
49
 
50
+
51
+ async def close(self):
52
+ if self._driver:
53
+ await self._driver.close()
54
+ self._driver = None
55
+
56
+
57
+
58
+ async def __aexit__(self, exc_type, exc, tb):
59
+ if self._driver:
60
+ await self._driver.close()
61
+
62
  async def index_done_callback(self):
63
  print ("KG successfully indexed.")
64
+
65
+
66
  async def has_node(self, node_id: str) -> bool:
67
  entity_name_label = node_id.strip('\"')
68
 
69
+ async with self._driver.session() as session:
70
+ query = f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
71
+ result = await session.run(query)
72
+ single_result = await result.single()
73
  logger.debug(
74
  f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["node_exists"]}'
75
  )
 
76
  return single_result["node_exists"]
77
+
 
 
 
 
78
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
79
  entity_name_label_source = source_node_id.strip('\"')
80
  entity_name_label_target = target_node_id.strip('\"')
81
 
82
+ async with self._driver.session() as session:
 
83
  query = (
84
+ f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
85
  "RETURN COUNT(r) > 0 AS edgeExists"
86
  )
87
+ result = await session.run(query)
88
+ single_result = await result.single()
89
  logger.debug(
90
  f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result["edgeExists"]}'
91
  )
 
92
  return single_result["edgeExists"]
93
+
94
  def close(self):
95
  self._driver.close()
96
+
 
 
 
97
 
98
 
99
 
100
  async def get_node(self, node_id: str) -> Union[dict, None]:
101
+ async with self._driver.session() as session:
102
+ entity_name_label = node_id.strip('\"')
103
+ query = f"MATCH (n:`{entity_name_label}`) RETURN n"
104
+ result = await session.run(query)
105
+ record = await result.single()
106
+ if record:
107
+ node = record["n"]
108
+ node_dict = dict(node)
109
  logger.debug(
110
+ f'{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}'
111
+ )
112
+ return node_dict
113
+ return None
114
+
115
 
116
 
117
  async def node_degree(self, node_id: str) -> int:
118
  entity_name_label = node_id.strip('\"')
119
 
120
+ async with self._driver.session() as session:
121
+ query = f"""
122
+ MATCH (n:`{entity_name_label}`)
123
+ RETURN COUNT{{ (n)--() }} AS totalEdgeCount
124
+ """
125
+ result = await session.run(query)
126
+ record = await result.single()
127
+ if record:
128
+ edge_count = record["totalEdgeCount"]
129
+ logger.debug(
130
+ f'{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}'
131
+ )
132
+ return edge_count
133
+ else:
134
+ return None
 
 
135
 
 
 
 
 
136
 
137
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
138
  entity_name_label_source = src_id.strip('\"')
139
  entity_name_label_target = tgt_id.strip('\"')
140
+ src_degree = await self.node_degree(entity_name_label_source)
141
+ trg_degree = await self.node_degree(entity_name_label_target)
142
+
143
+ # Convert None to 0 for addition
144
+ src_degree = 0 if src_degree is None else src_degree
145
+ trg_degree = 0 if trg_degree is None else trg_degree
146
+
147
+ degrees = int(src_degree) + int(trg_degree)
148
+ logger.debug(
149
+ f'{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}'
150
+ )
151
+ return degrees
152
+
153
+
154
 
155
  async def get_edge(self, source_node_id: str, target_node_id: str) -> Union[dict, None]:
156
  entity_name_label_source = source_node_id.strip('\"')
 
165
  Returns:
166
  list: List of all relationships/edges found
167
  """
168
+ async with self._driver.session() as session:
169
  query = f"""
170
  MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
171
  RETURN properties(r) as edge_properties
172
  LIMIT 1
173
  """.format(entity_name_label_source=entity_name_label_source, entity_name_label_target=entity_name_label_target)
174
 
175
+ result = await session.run(query)
176
+ record = await result.single()
177
  if record:
178
  result = dict(record["edge_properties"])
179
  logger.debug(
 
184
  return None
185
 
186
 
187
+ async def get_node_edges(self, source_node_id: str)-> List[Tuple[str, str]]:
188
  node_label = source_node_id.strip('\"')
189
 
190
  """
191
+ Retrieves all edges (relationships) for a particular node identified by its label.
 
 
 
 
 
 
192
  :return: List of dictionaries containing edge information
193
  """
194
+ query = f"""MATCH (n:`{node_label}`)
 
 
195
  OPTIONAL MATCH (n)-[r]-(connected)
196
  RETURN n, r, connected"""
197
+ async with self._driver.session() as session:
198
+ results = await session.run(query)
 
199
  edges = []
200
+ async for record in results:
201
  source_node = record['n']
202
  connected_node = record['connected']
203
 
 
209
 
210
  return edges
211
 
212
+ async with self._driver.session() as session:
213
  edges = session.read_transaction(fetch_edges,node_label)
214
  return edges
215
 
 
219
  wait=wait_exponential(multiplier=1, min=4, max=10),
220
  retry=retry_if_exception_type((neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable)),
221
  )
222
+ async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
 
 
223
  """
224
+ Upsert a node in the Neo4j database.
225
+
226
  Args:
227
+ node_id: The unique identifier for the node (used as label)
228
+ node_data: Dictionary of node properties
 
 
 
229
  """
230
+ label = node_id.strip('\"')
231
+ properties = node_data
 
 
 
 
 
 
 
 
 
232
 
233
+ async def _do_upsert(tx: AsyncManagedTransaction):
234
  query = f"""
235
  MERGE (n:`{label}`)
236
  SET n += $properties
 
237
  """
238
+ await tx.run(query, properties=properties)
239
+ logger.debug(f"Upserted node with label '{label}' and properties: {properties}")
240
+
241
+ try:
242
+ async with self._driver.session() as session:
243
+ await session.execute_write(_do_upsert)
244
+ except Exception as e:
245
+ logger.error(f"Error during upsert: {str(e)}")
246
+ raise
 
 
 
 
 
 
 
 
 
 
 
 
247
 
248
+ @retry(
249
+ stop=stop_after_attempt(3),
250
+ wait=wait_exponential(multiplier=1, min=4, max=10),
251
+ retry=retry_if_exception_type((neo4jExceptions.ServiceUnavailable, neo4jExceptions.TransientError, neo4jExceptions.WriteServiceUnavailable)),
252
+ )
253
+ async def upsert_edge(self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]):
254
  """
255
  Upsert an edge and its properties between two nodes identified by their labels.
256
+
257
  Args:
258
+ source_node_id (str): Label of the source node (used as identifier)
259
+ target_node_id (str): Label of the target node (used as identifier)
260
+ edge_data (dict): Dictionary of properties to set on the edge
261
  """
262
+ source_node_label = source_node_id.strip('\"')
263
+ target_node_label = target_node_id.strip('\"')
264
+ edge_properties = edge_data
 
 
 
 
 
 
 
 
 
 
 
265
 
266
+ async def _do_upsert_edge(tx: AsyncManagedTransaction):
267
  query = f"""
268
  MATCH (source:`{source_node_label}`)
269
  WITH source
 
272
  SET r += $properties
273
  RETURN r
274
  """
275
+ await tx.run(query, properties=edge_properties)
276
+ logger.debug(f"Upserted edge from '{source_node_label}' to '{target_node_label}' with properties: {edge_properties}")
277
+
278
+ try:
279
+ async with self._driver.session() as session:
280
+ await session.execute_write(_do_upsert_edge)
281
+ except Exception as e:
282
+ logger.error(f"Error during edge upsert: {str(e)}")
283
+ raise
 
 
 
 
 
 
 
284
  async def _node2vec_embed(self):
285
  print ("Implemented but never called.")
286
 
lightrag/lightrag.py CHANGED
@@ -26,7 +26,7 @@ from .storage import (
26
  )
27
 
28
  from .kg.neo4j_impl import (
29
- GraphStorage as Neo4JStorage
30
  )
31
  #future KG integrations
32
 
@@ -57,9 +57,10 @@ def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
57
  try:
58
  loop = asyncio.get_running_loop()
59
  except RuntimeError:
60
- logger.info("Creating a new event loop in a sub-thread.")
61
- loop = asyncio.new_event_loop()
62
- asyncio.set_event_loop(loop)
 
63
  return loop
64
 
65
 
@@ -329,4 +330,4 @@ class LightRAG:
329
  if storage_inst is None:
330
  continue
331
  tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
332
- await asyncio.gather(*tasks)
 
26
  )
27
 
28
  from .kg.neo4j_impl import (
29
+ Neo4JStorage
30
  )
31
  #future KG integrations
32
 
 
57
  try:
58
  loop = asyncio.get_running_loop()
59
  except RuntimeError:
60
+ logger.info("Creating a new event loop in main thread.")
61
+ # loop = asyncio.new_event_loop()
62
+ # asyncio.set_event_loop(loop)
63
+ loop = asyncio.get_event_loop()
64
  return loop
65
 
66
 
 
330
  if storage_inst is None:
331
  continue
332
  tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
333
+ await asyncio.gather(*tasks)
lightrag/llm.py CHANGED
@@ -798,4 +798,4 @@ if __name__ == "__main__":
798
  result = await gpt_4o_mini_complete("How are you?")
799
  print(result)
800
 
801
- asyncio.run(main())
 
798
  result = await gpt_4o_mini_complete("How are you?")
799
  print(result)
800
 
801
+ asyncio.run(main())
lightrag/operate.py CHANGED
@@ -1083,4 +1083,4 @@ async def naive_query(
1083
  .strip()
1084
  )
1085
 
1086
- return response
 
1083
  .strip()
1084
  )
1085
 
1086
+ return response
test_neo4j.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  from lightrag import LightRAG, QueryParam
3
  from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
4
 
 
5
  #########
6
  # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
7
  # import nest_asyncio
 
2
  from lightrag import LightRAG, QueryParam
3
  from lightrag.llm import gpt_4o_mini_complete, gpt_4o_complete
4
 
5
+
6
  #########
7
  # Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
8
  # import nest_asyncio