Samuel Chan committed on
Commit
b57ed74
·
unverified ·
2 Parent(s): b6db833 09b6026

Merge branch 'HKUDS:main' into main

Browse files
Files changed (1) hide show
  1. lightrag/kg/neo4j_impl.py +52 -14
lightrag/kg/neo4j_impl.py CHANGED
@@ -1,18 +1,16 @@
1
  import asyncio
 
2
  import os
3
  from dataclasses import dataclass
4
  from typing import Any, Union, Tuple, List, Dict
5
- import inspect
6
- from lightrag.utils import logger
7
- from ..base import BaseGraphStorage
8
  from neo4j import (
9
  AsyncGraphDatabase,
10
  exceptions as neo4jExceptions,
11
  AsyncDriver,
12
  AsyncManagedTransaction,
 
13
  )
14
-
15
-
16
  from tenacity import (
17
  retry,
18
  stop_after_attempt,
@@ -20,6 +18,9 @@ from tenacity import (
20
  retry_if_exception_type,
21
  )
22
 
 
 
 
23
 
24
  @dataclass
25
  class Neo4JStorage(BaseGraphStorage):
@@ -38,10 +39,47 @@ class Neo4JStorage(BaseGraphStorage):
38
  URI = os.environ["NEO4J_URI"]
39
  USERNAME = os.environ["NEO4J_USERNAME"]
40
  PASSWORD = os.environ["NEO4J_PASSWORD"]
 
 
 
 
41
  self._driver: AsyncDriver = AsyncGraphDatabase.driver(
42
  URI, auth=(USERNAME, PASSWORD)
43
  )
44
- return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  def __post_init__(self):
47
  self._node_embed_algorithms = {
@@ -63,7 +101,7 @@ class Neo4JStorage(BaseGraphStorage):
63
  async def has_node(self, node_id: str) -> bool:
64
  entity_name_label = node_id.strip('"')
65
 
66
- async with self._driver.session() as session:
67
  query = (
68
  f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
69
  )
@@ -78,7 +116,7 @@ class Neo4JStorage(BaseGraphStorage):
78
  entity_name_label_source = source_node_id.strip('"')
79
  entity_name_label_target = target_node_id.strip('"')
80
 
81
- async with self._driver.session() as session:
82
  query = (
83
  f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
84
  "RETURN COUNT(r) > 0 AS edgeExists"
@@ -91,7 +129,7 @@ class Neo4JStorage(BaseGraphStorage):
91
  return single_result["edgeExists"]
92
 
93
  async def get_node(self, node_id: str) -> Union[dict, None]:
94
- async with self._driver.session() as session:
95
  entity_name_label = node_id.strip('"')
96
  query = f"MATCH (n:`{entity_name_label}`) RETURN n"
97
  result = await session.run(query)
@@ -108,7 +146,7 @@ class Neo4JStorage(BaseGraphStorage):
108
  async def node_degree(self, node_id: str) -> int:
109
  entity_name_label = node_id.strip('"')
110
 
111
- async with self._driver.session() as session:
112
  query = f"""
113
  MATCH (n:`{entity_name_label}`)
114
  RETURN COUNT{{ (n)--() }} AS totalEdgeCount
@@ -155,7 +193,7 @@ class Neo4JStorage(BaseGraphStorage):
155
  Returns:
156
  list: List of all relationships/edges found
157
  """
158
- async with self._driver.session() as session:
159
  query = f"""
160
  MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
161
  RETURN properties(r) as edge_properties
@@ -186,7 +224,7 @@ class Neo4JStorage(BaseGraphStorage):
186
  query = f"""MATCH (n:`{node_label}`)
187
  OPTIONAL MATCH (n)-[r]-(connected)
188
  RETURN n, r, connected"""
189
- async with self._driver.session() as session:
190
  results = await session.run(query)
191
  edges = []
192
  async for record in results:
@@ -241,7 +279,7 @@ class Neo4JStorage(BaseGraphStorage):
241
  )
242
 
243
  try:
244
- async with self._driver.session() as session:
245
  await session.execute_write(_do_upsert)
246
  except Exception as e:
247
  logger.error(f"Error during upsert: {str(e)}")
@@ -288,7 +326,7 @@ class Neo4JStorage(BaseGraphStorage):
288
  )
289
 
290
  try:
291
- async with self._driver.session() as session:
292
  await session.execute_write(_do_upsert_edge)
293
  except Exception as e:
294
  logger.error(f"Error during edge upsert: {str(e)}")
 
1
  import asyncio
2
+ import inspect
3
  import os
4
  from dataclasses import dataclass
5
  from typing import Any, Union, Tuple, List, Dict
6
+
 
 
7
  from neo4j import (
8
  AsyncGraphDatabase,
9
  exceptions as neo4jExceptions,
10
  AsyncDriver,
11
  AsyncManagedTransaction,
12
+ GraphDatabase,
13
  )
 
 
14
  from tenacity import (
15
  retry,
16
  stop_after_attempt,
 
18
  retry_if_exception_type,
19
  )
20
 
21
+ from lightrag.utils import logger
22
+ from ..base import BaseGraphStorage
23
+
24
 
25
  @dataclass
26
  class Neo4JStorage(BaseGraphStorage):
 
39
  URI = os.environ["NEO4J_URI"]
40
  USERNAME = os.environ["NEO4J_USERNAME"]
41
  PASSWORD = os.environ["NEO4J_PASSWORD"]
42
+ DATABASE = os.environ.get(
43
+ "NEO4J_DATABASE"
44
+ ) # If this param is None, the home database will be used. If it is not None, the specified database will be used.
45
+ self._DATABASE = DATABASE
46
  self._driver: AsyncDriver = AsyncGraphDatabase.driver(
47
  URI, auth=(USERNAME, PASSWORD)
48
  )
49
+ _database_name = "home database" if DATABASE is None else f"database {DATABASE}"
50
+ with GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD)) as _sync_driver:
51
+ try:
52
+ with _sync_driver.session(database=DATABASE) as session:
53
+ try:
54
+ session.run("MATCH (n) RETURN n LIMIT 0")
55
+ logger.info(f"Connected to {DATABASE} at {URI}")
56
+ except neo4jExceptions.ServiceUnavailable as e:
57
+ logger.error(
58
+ f"{DATABASE} at {URI} is not available".capitalize()
59
+ )
60
+ raise e
61
+ except neo4jExceptions.AuthError as e:
62
+ logger.error(f"Authentication failed for {DATABASE} at {URI}")
63
+ raise e
64
+ except neo4jExceptions.ClientError as e:
65
+ if e.code == "Neo.ClientError.Database.DatabaseNotFound":
66
+ logger.info(
67
+ f"{DATABASE} at {URI} not found. Try to create specified database.".capitalize()
68
+ )
69
+ try:
70
+ with _sync_driver.session() as session:
71
+ session.run(f"CREATE DATABASE `{DATABASE}` IF NOT EXISTS")
72
+ logger.info(f"{DATABASE} at {URI} created".capitalize())
73
+ except neo4jExceptions.ClientError as e:
74
+ if (
75
+ e.code
76
+ == "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
77
+ ):
78
+ logger.warning(
79
+ "This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead."
80
+ )
81
+ logger.error(f"Failed to create {DATABASE} at {URI}")
82
+ raise e
83
 
84
  def __post_init__(self):
85
  self._node_embed_algorithms = {
 
101
  async def has_node(self, node_id: str) -> bool:
102
  entity_name_label = node_id.strip('"')
103
 
104
+ async with self._driver.session(database=self._DATABASE) as session:
105
  query = (
106
  f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
107
  )
 
116
  entity_name_label_source = source_node_id.strip('"')
117
  entity_name_label_target = target_node_id.strip('"')
118
 
119
+ async with self._driver.session(database=self._DATABASE) as session:
120
  query = (
121
  f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
122
  "RETURN COUNT(r) > 0 AS edgeExists"
 
129
  return single_result["edgeExists"]
130
 
131
  async def get_node(self, node_id: str) -> Union[dict, None]:
132
+ async with self._driver.session(database=self._DATABASE) as session:
133
  entity_name_label = node_id.strip('"')
134
  query = f"MATCH (n:`{entity_name_label}`) RETURN n"
135
  result = await session.run(query)
 
146
  async def node_degree(self, node_id: str) -> int:
147
  entity_name_label = node_id.strip('"')
148
 
149
+ async with self._driver.session(database=self._DATABASE) as session:
150
  query = f"""
151
  MATCH (n:`{entity_name_label}`)
152
  RETURN COUNT{{ (n)--() }} AS totalEdgeCount
 
193
  Returns:
194
  list: List of all relationships/edges found
195
  """
196
+ async with self._driver.session(database=self._DATABASE) as session:
197
  query = f"""
198
  MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
199
  RETURN properties(r) as edge_properties
 
224
  query = f"""MATCH (n:`{node_label}`)
225
  OPTIONAL MATCH (n)-[r]-(connected)
226
  RETURN n, r, connected"""
227
+ async with self._driver.session(database=self._DATABASE) as session:
228
  results = await session.run(query)
229
  edges = []
230
  async for record in results:
 
279
  )
280
 
281
  try:
282
+ async with self._driver.session(database=self._DATABASE) as session:
283
  await session.execute_write(_do_upsert)
284
  except Exception as e:
285
  logger.error(f"Error during upsert: {str(e)}")
 
326
  )
327
 
328
  try:
329
+ async with self._driver.session(database=self._DATABASE) as session:
330
  await session.execute_write(_do_upsert_edge)
331
  except Exception as e:
332
  logger.error(f"Error during edge upsert: {str(e)}")