Merge branch 'HKUDS:main' into main
Browse files- lightrag/kg/neo4j_impl.py +52 -14
lightrag/kg/neo4j_impl.py
CHANGED
@@ -1,18 +1,16 @@
|
|
1 |
import asyncio
|
|
|
2 |
import os
|
3 |
from dataclasses import dataclass
|
4 |
from typing import Any, Union, Tuple, List, Dict
|
5 |
-
|
6 |
-
from lightrag.utils import logger
|
7 |
-
from ..base import BaseGraphStorage
|
8 |
from neo4j import (
|
9 |
AsyncGraphDatabase,
|
10 |
exceptions as neo4jExceptions,
|
11 |
AsyncDriver,
|
12 |
AsyncManagedTransaction,
|
|
|
13 |
)
|
14 |
-
|
15 |
-
|
16 |
from tenacity import (
|
17 |
retry,
|
18 |
stop_after_attempt,
|
@@ -20,6 +18,9 @@ from tenacity import (
|
|
20 |
retry_if_exception_type,
|
21 |
)
|
22 |
|
|
|
|
|
|
|
23 |
|
24 |
@dataclass
|
25 |
class Neo4JStorage(BaseGraphStorage):
|
@@ -38,10 +39,47 @@ class Neo4JStorage(BaseGraphStorage):
|
|
38 |
URI = os.environ["NEO4J_URI"]
|
39 |
USERNAME = os.environ["NEO4J_USERNAME"]
|
40 |
PASSWORD = os.environ["NEO4J_PASSWORD"]
|
|
|
|
|
|
|
|
|
41 |
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
|
42 |
URI, auth=(USERNAME, PASSWORD)
|
43 |
)
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
def __post_init__(self):
|
47 |
self._node_embed_algorithms = {
|
@@ -63,7 +101,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|
63 |
async def has_node(self, node_id: str) -> bool:
|
64 |
entity_name_label = node_id.strip('"')
|
65 |
|
66 |
-
async with self._driver.session() as session:
|
67 |
query = (
|
68 |
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
|
69 |
)
|
@@ -78,7 +116,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|
78 |
entity_name_label_source = source_node_id.strip('"')
|
79 |
entity_name_label_target = target_node_id.strip('"')
|
80 |
|
81 |
-
async with self._driver.session() as session:
|
82 |
query = (
|
83 |
f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
|
84 |
"RETURN COUNT(r) > 0 AS edgeExists"
|
@@ -91,7 +129,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|
91 |
return single_result["edgeExists"]
|
92 |
|
93 |
async def get_node(self, node_id: str) -> Union[dict, None]:
|
94 |
-
async with self._driver.session() as session:
|
95 |
entity_name_label = node_id.strip('"')
|
96 |
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
|
97 |
result = await session.run(query)
|
@@ -108,7 +146,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|
108 |
async def node_degree(self, node_id: str) -> int:
|
109 |
entity_name_label = node_id.strip('"')
|
110 |
|
111 |
-
async with self._driver.session() as session:
|
112 |
query = f"""
|
113 |
MATCH (n:`{entity_name_label}`)
|
114 |
RETURN COUNT{{ (n)--() }} AS totalEdgeCount
|
@@ -155,7 +193,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|
155 |
Returns:
|
156 |
list: List of all relationships/edges found
|
157 |
"""
|
158 |
-
async with self._driver.session() as session:
|
159 |
query = f"""
|
160 |
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
|
161 |
RETURN properties(r) as edge_properties
|
@@ -186,7 +224,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|
186 |
query = f"""MATCH (n:`{node_label}`)
|
187 |
OPTIONAL MATCH (n)-[r]-(connected)
|
188 |
RETURN n, r, connected"""
|
189 |
-
async with self._driver.session() as session:
|
190 |
results = await session.run(query)
|
191 |
edges = []
|
192 |
async for record in results:
|
@@ -241,7 +279,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|
241 |
)
|
242 |
|
243 |
try:
|
244 |
-
async with self._driver.session() as session:
|
245 |
await session.execute_write(_do_upsert)
|
246 |
except Exception as e:
|
247 |
logger.error(f"Error during upsert: {str(e)}")
|
@@ -288,7 +326,7 @@ class Neo4JStorage(BaseGraphStorage):
|
|
288 |
)
|
289 |
|
290 |
try:
|
291 |
-
async with self._driver.session() as session:
|
292 |
await session.execute_write(_do_upsert_edge)
|
293 |
except Exception as e:
|
294 |
logger.error(f"Error during edge upsert: {str(e)}")
|
|
|
1 |
import asyncio
|
2 |
+
import inspect
|
3 |
import os
|
4 |
from dataclasses import dataclass
|
5 |
from typing import Any, Union, Tuple, List, Dict
|
6 |
+
|
|
|
|
|
7 |
from neo4j import (
|
8 |
AsyncGraphDatabase,
|
9 |
exceptions as neo4jExceptions,
|
10 |
AsyncDriver,
|
11 |
AsyncManagedTransaction,
|
12 |
+
GraphDatabase,
|
13 |
)
|
|
|
|
|
14 |
from tenacity import (
|
15 |
retry,
|
16 |
stop_after_attempt,
|
|
|
18 |
retry_if_exception_type,
|
19 |
)
|
20 |
|
21 |
+
from lightrag.utils import logger
|
22 |
+
from ..base import BaseGraphStorage
|
23 |
+
|
24 |
|
25 |
@dataclass
|
26 |
class Neo4JStorage(BaseGraphStorage):
|
|
|
39 |
URI = os.environ["NEO4J_URI"]
|
40 |
USERNAME = os.environ["NEO4J_USERNAME"]
|
41 |
PASSWORD = os.environ["NEO4J_PASSWORD"]
|
42 |
+
DATABASE = os.environ.get(
|
43 |
+
"NEO4J_DATABASE"
|
44 |
+
) # If this param is None, the home database will be used. If it is not None, the specified database will be used.
|
45 |
+
self._DATABASE = DATABASE
|
46 |
self._driver: AsyncDriver = AsyncGraphDatabase.driver(
|
47 |
URI, auth=(USERNAME, PASSWORD)
|
48 |
)
|
49 |
+
_database_name = "home database" if DATABASE is None else f"database {DATABASE}"
|
50 |
+
with GraphDatabase.driver(URI, auth=(USERNAME, PASSWORD)) as _sync_driver:
|
51 |
+
try:
|
52 |
+
with _sync_driver.session(database=DATABASE) as session:
|
53 |
+
try:
|
54 |
+
session.run("MATCH (n) RETURN n LIMIT 0")
|
55 |
+
logger.info(f"Connected to {DATABASE} at {URI}")
|
56 |
+
except neo4jExceptions.ServiceUnavailable as e:
|
57 |
+
logger.error(
|
58 |
+
f"{DATABASE} at {URI} is not available".capitalize()
|
59 |
+
)
|
60 |
+
raise e
|
61 |
+
except neo4jExceptions.AuthError as e:
|
62 |
+
logger.error(f"Authentication failed for {DATABASE} at {URI}")
|
63 |
+
raise e
|
64 |
+
except neo4jExceptions.ClientError as e:
|
65 |
+
if e.code == "Neo.ClientError.Database.DatabaseNotFound":
|
66 |
+
logger.info(
|
67 |
+
f"{DATABASE} at {URI} not found. Try to create specified database.".capitalize()
|
68 |
+
)
|
69 |
+
try:
|
70 |
+
with _sync_driver.session() as session:
|
71 |
+
session.run(f"CREATE DATABASE `{DATABASE}` IF NOT EXISTS")
|
72 |
+
logger.info(f"{DATABASE} at {URI} created".capitalize())
|
73 |
+
except neo4jExceptions.ClientError as e:
|
74 |
+
if (
|
75 |
+
e.code
|
76 |
+
== "Neo.ClientError.Statement.UnsupportedAdministrationCommand"
|
77 |
+
):
|
78 |
+
logger.warning(
|
79 |
+
"This Neo4j instance does not support creating databases. Try to use Neo4j Desktop/Enterprise version or DozerDB instead."
|
80 |
+
)
|
81 |
+
logger.error(f"Failed to create {DATABASE} at {URI}")
|
82 |
+
raise e
|
83 |
|
84 |
def __post_init__(self):
|
85 |
self._node_embed_algorithms = {
|
|
|
101 |
async def has_node(self, node_id: str) -> bool:
|
102 |
entity_name_label = node_id.strip('"')
|
103 |
|
104 |
+
async with self._driver.session(database=self._DATABASE) as session:
|
105 |
query = (
|
106 |
f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
|
107 |
)
|
|
|
116 |
entity_name_label_source = source_node_id.strip('"')
|
117 |
entity_name_label_target = target_node_id.strip('"')
|
118 |
|
119 |
+
async with self._driver.session(database=self._DATABASE) as session:
|
120 |
query = (
|
121 |
f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
|
122 |
"RETURN COUNT(r) > 0 AS edgeExists"
|
|
|
129 |
return single_result["edgeExists"]
|
130 |
|
131 |
async def get_node(self, node_id: str) -> Union[dict, None]:
|
132 |
+
async with self._driver.session(database=self._DATABASE) as session:
|
133 |
entity_name_label = node_id.strip('"')
|
134 |
query = f"MATCH (n:`{entity_name_label}`) RETURN n"
|
135 |
result = await session.run(query)
|
|
|
146 |
async def node_degree(self, node_id: str) -> int:
|
147 |
entity_name_label = node_id.strip('"')
|
148 |
|
149 |
+
async with self._driver.session(database=self._DATABASE) as session:
|
150 |
query = f"""
|
151 |
MATCH (n:`{entity_name_label}`)
|
152 |
RETURN COUNT{{ (n)--() }} AS totalEdgeCount
|
|
|
193 |
Returns:
|
194 |
list: List of all relationships/edges found
|
195 |
"""
|
196 |
+
async with self._driver.session(database=self._DATABASE) as session:
|
197 |
query = f"""
|
198 |
MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
|
199 |
RETURN properties(r) as edge_properties
|
|
|
224 |
query = f"""MATCH (n:`{node_label}`)
|
225 |
OPTIONAL MATCH (n)-[r]-(connected)
|
226 |
RETURN n, r, connected"""
|
227 |
+
async with self._driver.session(database=self._DATABASE) as session:
|
228 |
results = await session.run(query)
|
229 |
edges = []
|
230 |
async for record in results:
|
|
|
279 |
)
|
280 |
|
281 |
try:
|
282 |
+
async with self._driver.session(database=self._DATABASE) as session:
|
283 |
await session.execute_write(_do_upsert)
|
284 |
except Exception as e:
|
285 |
logger.error(f"Error during upsert: {str(e)}")
|
|
|
326 |
)
|
327 |
|
328 |
try:
|
329 |
+
async with self._driver.session(database=self._DATABASE) as session:
|
330 |
await session.execute_write(_do_upsert_edge)
|
331 |
except Exception as e:
|
332 |
logger.error(f"Error during edge upsert: {str(e)}")
|