Merge pull request #505 from alllexx88/arcadedb_gremlin_graph
Browse files- lightrag/kg/gremlin_impl.py +106 -127
lightrag/kg/gremlin_impl.py
CHANGED
@@ -2,7 +2,6 @@ import asyncio
|
|
2 |
import inspect
|
3 |
import json
|
4 |
import os
|
5 |
-
import re
|
6 |
from dataclasses import dataclass
|
7 |
from typing import Any, Dict, List, Tuple, Union
|
8 |
|
@@ -27,9 +26,6 @@ class GremlinStorage(BaseGraphStorage):
|
|
27 |
def load_nx_graph(file_name):
|
28 |
print("no preloading of graph with Gremlin in production")
|
29 |
|
30 |
-
# Will use this to make sure single quotes are properly escaped
|
31 |
-
escape_rx = re.compile(r"(^|[^\\])((\\\\)*\\)\\'")
|
32 |
-
|
33 |
def __init__(self, namespace, global_config, embedding_func):
|
34 |
super().__init__(
|
35 |
namespace=namespace,
|
@@ -51,12 +47,8 @@ class GremlinStorage(BaseGraphStorage):
|
|
51 |
|
52 |
# All vertices will have graph={GRAPH} property, so that we can
|
53 |
# have several logical graphs for one source
|
54 |
-
GRAPH = GremlinStorage.
|
55 |
-
r"\1\2'",
|
56 |
-
os.environ["GREMLIN_GRAPH"].replace("'", r"\'"),
|
57 |
-
)
|
58 |
|
59 |
-
self.traverse_source_name = SOURCE
|
60 |
self.graph_name = GRAPH
|
61 |
|
62 |
self._driver = client.Client(
|
@@ -87,7 +79,7 @@ class GremlinStorage(BaseGraphStorage):
|
|
87 |
|
88 |
@staticmethod
|
89 |
def _to_value_map(value: Any) -> str:
|
90 |
-
"""Dump Python
|
91 |
json_str = json.dumps(value, ensure_ascii=False, sort_keys=False)
|
92 |
parsed_str = json_str.replace("'", r"\'")
|
93 |
|
@@ -122,17 +114,16 @@ class GremlinStorage(BaseGraphStorage):
|
|
122 |
"""Create chained .property() commands from properties dict"""
|
123 |
props = []
|
124 |
for k, v in properties.items():
|
125 |
-
prop_name = GremlinStorage.
|
126 |
-
props.append(f".property(
|
127 |
return "".join(props)
|
128 |
|
129 |
@staticmethod
|
130 |
-
def
|
131 |
-
"""Strip double quotes and
|
132 |
-
|
133 |
-
label = GremlinStorage.escape_rx.sub(r"\1\2'", label)
|
134 |
|
135 |
-
return
|
136 |
|
137 |
async def _query(self, query: str) -> List[Dict[str, Any]]:
|
138 |
"""
|
@@ -146,66 +137,69 @@ class GremlinStorage(BaseGraphStorage):
|
|
146 |
"""
|
147 |
|
148 |
result = list(await asyncio.wrap_future(self._driver.submit_async(query)))
|
|
|
|
|
149 |
|
150 |
return result
|
151 |
|
152 |
async def has_node(self, node_id: str) -> bool:
|
153 |
-
|
154 |
|
155 |
-
query = f"""
|
156 |
-
{self.
|
157 |
-
.
|
158 |
-
.hasLabel('{entity_name_label}')
|
159 |
.limit(1)
|
160 |
-
.
|
|
|
|
|
161 |
"""
|
162 |
result = await self._query(query)
|
163 |
logger.debug(
|
164 |
"{%s}:query:{%s}:result:{%s}",
|
165 |
inspect.currentframe().f_code.co_name,
|
166 |
query,
|
167 |
-
result[0][
|
168 |
)
|
169 |
|
170 |
-
return result[0][
|
171 |
|
172 |
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
query = f"""
|
177 |
-
{self.
|
178 |
-
.
|
179 |
-
.
|
180 |
-
.
|
181 |
-
.
|
182 |
-
.hasLabel('{entity_name_label_target}')
|
183 |
.limit(1)
|
184 |
-
.
|
|
|
|
|
185 |
"""
|
186 |
result = await self._query(query)
|
187 |
logger.debug(
|
188 |
"{%s}:query:{%s}:result:{%s}",
|
189 |
inspect.currentframe().f_code.co_name,
|
190 |
query,
|
191 |
-
result[0][
|
192 |
)
|
193 |
|
194 |
-
return result[0][
|
195 |
|
196 |
async def get_node(self, node_id: str) -> Union[dict, None]:
|
197 |
-
|
198 |
-
query = f"""
|
199 |
-
{self.
|
200 |
-
.
|
201 |
-
.hasLabel('{entity_name_label}')
|
202 |
.limit(1)
|
203 |
.project('properties')
|
204 |
.by(elementMap())
|
205 |
"""
|
206 |
result = await self._query(query)
|
207 |
if result:
|
208 |
-
node = result[0]
|
209 |
node_dict = node["properties"]
|
210 |
logger.debug(
|
211 |
"{%s}: query: {%s}, result: {%s}",
|
@@ -216,19 +210,18 @@ class GremlinStorage(BaseGraphStorage):
|
|
216 |
return node_dict
|
217 |
|
218 |
async def node_degree(self, node_id: str) -> int:
|
219 |
-
|
220 |
-
query = f"""
|
221 |
-
{self.
|
222 |
-
.
|
223 |
-
.hasLabel('{entity_name_label}')
|
224 |
.outE()
|
225 |
-
.inV().has('graph',
|
226 |
.count()
|
227 |
.project('total_edge_count')
|
228 |
.by()
|
229 |
"""
|
230 |
result = await self._query(query)
|
231 |
-
edge_count = result[0][
|
232 |
|
233 |
logger.debug(
|
234 |
"{%s}:query:{%s}:result:{%s}",
|
@@ -259,31 +252,30 @@ class GremlinStorage(BaseGraphStorage):
|
|
259 |
self, source_node_id: str, target_node_id: str
|
260 |
) -> Union[dict, None]:
|
261 |
"""
|
262 |
-
Find all edges between nodes of two given
|
263 |
|
264 |
Args:
|
265 |
-
|
266 |
-
|
267 |
|
268 |
Returns:
|
269 |
-
dict|None: Dict of found edge properties, or None
|
270 |
"""
|
271 |
-
|
272 |
-
|
273 |
-
query = f"""
|
274 |
-
{self.
|
275 |
-
.
|
276 |
-
.hasLabel('{entity_name_label_source}')
|
277 |
.outE()
|
278 |
-
.inV().has('graph',
|
279 |
-
.
|
280 |
.limit(1)
|
281 |
.project('edge_properties')
|
282 |
.by(__.bothE().elementMap())
|
283 |
"""
|
284 |
result = await self._query(query)
|
285 |
if result:
|
286 |
-
edge_properties = result[0][
|
287 |
logger.debug(
|
288 |
"{%s}:query:{%s}:result:{%s}",
|
289 |
inspect.currentframe().f_code.co_name,
|
@@ -294,45 +286,31 @@ class GremlinStorage(BaseGraphStorage):
|
|
294 |
|
295 |
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
|
296 |
"""
|
297 |
-
Retrieves all edges (relationships) for a particular node identified by its
|
298 |
:return: List of tuples containing edge sources and targets
|
299 |
"""
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
.
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
.
|
312 |
-
.
|
313 |
-
.
|
314 |
-
.hasLabel('{node_label}')
|
315 |
-
.project('connected_label')
|
316 |
-
.by(__.select('connected').label())
|
317 |
"""
|
318 |
-
|
319 |
-
|
320 |
-
)
|
321 |
-
edges1 = (
|
322 |
-
[(node_label, res["connected_label"]) for res in result1[0]]
|
323 |
-
if result1
|
324 |
-
else []
|
325 |
-
)
|
326 |
-
edges2 = (
|
327 |
-
[(res["connected_label"], node_label) for res in result2[0]]
|
328 |
-
if result2
|
329 |
-
else []
|
330 |
-
)
|
331 |
|
332 |
-
return
|
333 |
|
334 |
@retry(
|
335 |
-
stop=stop_after_attempt(
|
336 |
wait=wait_exponential(multiplier=1, min=4, max=10),
|
337 |
retry=retry_if_exception_type((GremlinServerError,)),
|
338 |
)
|
@@ -341,28 +319,30 @@ class GremlinStorage(BaseGraphStorage):
|
|
341 |
Upsert a node in the Gremlin graph.
|
342 |
|
343 |
Args:
|
344 |
-
node_id: The unique identifier for the node (used as
|
345 |
node_data: Dictionary of node properties
|
346 |
"""
|
347 |
-
|
348 |
properties = GremlinStorage._convert_properties(node_data)
|
349 |
|
350 |
-
query = f"""
|
351 |
-
{self.
|
352 |
-
.
|
353 |
-
.
|
354 |
.coalesce(
|
355 |
-
unfold(),
|
356 |
-
addV('
|
357 |
-
|
|
|
|
|
358 |
{properties}
|
359 |
"""
|
360 |
|
361 |
try:
|
362 |
await self._query(query)
|
363 |
logger.debug(
|
364 |
-
"Upserted node with
|
365 |
-
|
366 |
properties,
|
367 |
)
|
368 |
except Exception as e:
|
@@ -370,7 +350,7 @@ class GremlinStorage(BaseGraphStorage):
|
|
370 |
raise
|
371 |
|
372 |
@retry(
|
373 |
-
stop=stop_after_attempt(
|
374 |
wait=wait_exponential(multiplier=1, min=4, max=10),
|
375 |
retry=retry_if_exception_type((GremlinServerError,)),
|
376 |
)
|
@@ -378,36 +358,35 @@ class GremlinStorage(BaseGraphStorage):
|
|
378 |
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
|
379 |
):
|
380 |
"""
|
381 |
-
Upsert an edge and its properties between two nodes identified by their
|
382 |
|
383 |
Args:
|
384 |
-
source_node_id (str):
|
385 |
-
target_node_id (str):
|
386 |
edge_data (dict): Dictionary of properties to set on the edge
|
387 |
"""
|
388 |
-
|
389 |
-
|
390 |
edge_properties = GremlinStorage._convert_properties(edge_data)
|
391 |
|
392 |
-
query = f"""
|
393 |
-
{self.
|
394 |
-
.
|
395 |
-
.
|
396 |
-
.
|
397 |
-
.hasLabel('{target_node_label}').as('target')
|
398 |
.coalesce(
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
{edge_properties}
|
404 |
"""
|
405 |
try:
|
406 |
await self._query(query)
|
407 |
logger.debug(
|
408 |
-
"Upserted edge from
|
409 |
-
|
410 |
-
|
411 |
edge_properties,
|
412 |
)
|
413 |
except Exception as e:
|
|
|
2 |
import inspect
|
3 |
import json
|
4 |
import os
|
|
|
5 |
from dataclasses import dataclass
|
6 |
from typing import Any, Dict, List, Tuple, Union
|
7 |
|
|
|
26 |
def load_nx_graph(file_name):
|
27 |
print("no preloading of graph with Gremlin in production")
|
28 |
|
|
|
|
|
|
|
29 |
def __init__(self, namespace, global_config, embedding_func):
|
30 |
super().__init__(
|
31 |
namespace=namespace,
|
|
|
47 |
|
48 |
# All vertices will have graph={GRAPH} property, so that we can
|
49 |
# have several logical graphs for one source
|
50 |
+
GRAPH = GremlinStorage._to_value_map(os.environ["GREMLIN_GRAPH"])
|
|
|
|
|
|
|
51 |
|
|
|
52 |
self.graph_name = GRAPH
|
53 |
|
54 |
self._driver = client.Client(
|
|
|
79 |
|
80 |
@staticmethod
|
81 |
def _to_value_map(value: Any) -> str:
|
82 |
+
"""Dump supported Python object as Gremlin valueMap"""
|
83 |
json_str = json.dumps(value, ensure_ascii=False, sort_keys=False)
|
84 |
parsed_str = json_str.replace("'", r"\'")
|
85 |
|
|
|
114 |
"""Create chained .property() commands from properties dict"""
|
115 |
props = []
|
116 |
for k, v in properties.items():
|
117 |
+
prop_name = GremlinStorage._to_value_map(k)
|
118 |
+
props.append(f".property({prop_name}, {GremlinStorage._to_value_map(v)})")
|
119 |
return "".join(props)
|
120 |
|
121 |
@staticmethod
|
122 |
+
def _fix_name(name: str) -> str:
|
123 |
+
"""Strip double quotes and format as a proper field name"""
|
124 |
+
name = GremlinStorage._to_value_map(name.strip('"').replace(r"\'", "'"))
|
|
|
125 |
|
126 |
+
return name
|
127 |
|
128 |
async def _query(self, query: str) -> List[Dict[str, Any]]:
|
129 |
"""
|
|
|
137 |
"""
|
138 |
|
139 |
result = list(await asyncio.wrap_future(self._driver.submit_async(query)))
|
140 |
+
if result:
|
141 |
+
result = result[0]
|
142 |
|
143 |
return result
|
144 |
|
145 |
async def has_node(self, node_id: str) -> bool:
|
146 |
+
entity_name = GremlinStorage._fix_name(node_id)
|
147 |
|
148 |
+
query = f"""g
|
149 |
+
.V().has('graph', {self.graph_name})
|
150 |
+
.has('entity_name', {entity_name})
|
|
|
151 |
.limit(1)
|
152 |
+
.count()
|
153 |
+
.project('has_node')
|
154 |
+
.by(__.choose(__.is(gt(0)), constant(true), constant(false)))
|
155 |
"""
|
156 |
result = await self._query(query)
|
157 |
logger.debug(
|
158 |
"{%s}:query:{%s}:result:{%s}",
|
159 |
inspect.currentframe().f_code.co_name,
|
160 |
query,
|
161 |
+
result[0]["has_node"],
|
162 |
)
|
163 |
|
164 |
+
return result[0]["has_node"]
|
165 |
|
166 |
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
|
167 |
+
entity_name_source = GremlinStorage._fix_name(source_node_id)
|
168 |
+
entity_name_target = GremlinStorage._fix_name(target_node_id)
|
169 |
+
|
170 |
+
query = f"""g
|
171 |
+
.V().has('graph', {self.graph_name})
|
172 |
+
.has('entity_name', {entity_name_source})
|
173 |
+
.outE()
|
174 |
+
.inV().has('graph', {self.graph_name})
|
175 |
+
.has('entity_name', {entity_name_target})
|
|
|
176 |
.limit(1)
|
177 |
+
.count()
|
178 |
+
.project('has_edge')
|
179 |
+
.by(__.choose(__.is(gt(0)), constant(true), constant(false)))
|
180 |
"""
|
181 |
result = await self._query(query)
|
182 |
logger.debug(
|
183 |
"{%s}:query:{%s}:result:{%s}",
|
184 |
inspect.currentframe().f_code.co_name,
|
185 |
query,
|
186 |
+
result[0]["has_edge"],
|
187 |
)
|
188 |
|
189 |
+
return result[0]["has_edge"]
|
190 |
|
191 |
async def get_node(self, node_id: str) -> Union[dict, None]:
|
192 |
+
entity_name = GremlinStorage._fix_name(node_id)
|
193 |
+
query = f"""g
|
194 |
+
.V().has('graph', {self.graph_name})
|
195 |
+
.has('entity_name', {entity_name})
|
|
|
196 |
.limit(1)
|
197 |
.project('properties')
|
198 |
.by(elementMap())
|
199 |
"""
|
200 |
result = await self._query(query)
|
201 |
if result:
|
202 |
+
node = result[0]
|
203 |
node_dict = node["properties"]
|
204 |
logger.debug(
|
205 |
"{%s}: query: {%s}, result: {%s}",
|
|
|
210 |
return node_dict
|
211 |
|
212 |
async def node_degree(self, node_id: str) -> int:
|
213 |
+
entity_name = GremlinStorage._fix_name(node_id)
|
214 |
+
query = f"""g
|
215 |
+
.V().has('graph', {self.graph_name})
|
216 |
+
.has('entity_name', {entity_name})
|
|
|
217 |
.outE()
|
218 |
+
.inV().has('graph', {self.graph_name})
|
219 |
.count()
|
220 |
.project('total_edge_count')
|
221 |
.by()
|
222 |
"""
|
223 |
result = await self._query(query)
|
224 |
+
edge_count = result[0]["total_edge_count"]
|
225 |
|
226 |
logger.debug(
|
227 |
"{%s}:query:{%s}:result:{%s}",
|
|
|
252 |
self, source_node_id: str, target_node_id: str
|
253 |
) -> Union[dict, None]:
|
254 |
"""
|
255 |
+
Find all edges between nodes of two given names
|
256 |
|
257 |
Args:
|
258 |
+
source_node_id (str): Name of the source nodes
|
259 |
+
target_node_id (str): Name of the target nodes
|
260 |
|
261 |
Returns:
|
262 |
+
dict|None: Dict of found edge properties, or None if not found
|
263 |
"""
|
264 |
+
entity_name_source = GremlinStorage._fix_name(source_node_id)
|
265 |
+
entity_name_target = GremlinStorage._fix_name(target_node_id)
|
266 |
+
query = f"""g
|
267 |
+
.V().has('graph', {self.graph_name})
|
268 |
+
.has('entity_name', {entity_name_source})
|
|
|
269 |
.outE()
|
270 |
+
.inV().has('graph', {self.graph_name})
|
271 |
+
.has('entity_name', {entity_name_target})
|
272 |
.limit(1)
|
273 |
.project('edge_properties')
|
274 |
.by(__.bothE().elementMap())
|
275 |
"""
|
276 |
result = await self._query(query)
|
277 |
if result:
|
278 |
+
edge_properties = result[0]["edge_properties"]
|
279 |
logger.debug(
|
280 |
"{%s}:query:{%s}:result:{%s}",
|
281 |
inspect.currentframe().f_code.co_name,
|
|
|
286 |
|
287 |
async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
|
288 |
"""
|
289 |
+
Retrieves all edges (relationships) for a particular node identified by its name.
|
290 |
:return: List of tuples containing edge sources and targets
|
291 |
"""
|
292 |
+
node_name = GremlinStorage._fix_name(source_node_id)
|
293 |
+
query = f"""g
|
294 |
+
.E()
|
295 |
+
.filter(
|
296 |
+
__.or(
|
297 |
+
__.outV().has('graph', {self.graph_name})
|
298 |
+
.has('entity_name', {node_name}),
|
299 |
+
__.inV().has('graph', {self.graph_name})
|
300 |
+
.has('entity_name', {node_name})
|
301 |
+
)
|
302 |
+
)
|
303 |
+
.project('source_name', 'target_name')
|
304 |
+
.by(__.outV().values('entity_name'))
|
305 |
+
.by(__.inV().values('entity_name'))
|
|
|
|
|
|
|
306 |
"""
|
307 |
+
result = await self._query(query)
|
308 |
+
edges = [(res["source_name"], res["target_name"]) for res in result]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
309 |
|
310 |
+
return edges
|
311 |
|
312 |
@retry(
|
313 |
+
stop=stop_after_attempt(10),
|
314 |
wait=wait_exponential(multiplier=1, min=4, max=10),
|
315 |
retry=retry_if_exception_type((GremlinServerError,)),
|
316 |
)
|
|
|
319 |
Upsert a node in the Gremlin graph.
|
320 |
|
321 |
Args:
|
322 |
+
node_id: The unique identifier for the node (used as name)
|
323 |
node_data: Dictionary of node properties
|
324 |
"""
|
325 |
+
name = GremlinStorage._fix_name(node_id)
|
326 |
properties = GremlinStorage._convert_properties(node_data)
|
327 |
|
328 |
+
query = f"""g
|
329 |
+
.V().has('graph', {self.graph_name})
|
330 |
+
.has('entity_name', {name})
|
331 |
+
.fold()
|
332 |
.coalesce(
|
333 |
+
__.unfold(),
|
334 |
+
__.addV('ENTITY')
|
335 |
+
.property('graph', {self.graph_name})
|
336 |
+
.property('entity_name', {name})
|
337 |
+
)
|
338 |
{properties}
|
339 |
"""
|
340 |
|
341 |
try:
|
342 |
await self._query(query)
|
343 |
logger.debug(
|
344 |
+
"Upserted node with name {%s} and properties: {%s}",
|
345 |
+
name,
|
346 |
properties,
|
347 |
)
|
348 |
except Exception as e:
|
|
|
350 |
raise
|
351 |
|
352 |
@retry(
|
353 |
+
stop=stop_after_attempt(10),
|
354 |
wait=wait_exponential(multiplier=1, min=4, max=10),
|
355 |
retry=retry_if_exception_type((GremlinServerError,)),
|
356 |
)
|
|
|
358 |
self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
|
359 |
):
|
360 |
"""
|
361 |
+
Upsert an edge and its properties between two nodes identified by their names.
|
362 |
|
363 |
Args:
|
364 |
+
source_node_id (str): Name of the source node (used as identifier)
|
365 |
+
target_node_id (str): Name of the target node (used as identifier)
|
366 |
edge_data (dict): Dictionary of properties to set on the edge
|
367 |
"""
|
368 |
+
source_node_name = GremlinStorage._fix_name(source_node_id)
|
369 |
+
target_node_name = GremlinStorage._fix_name(target_node_id)
|
370 |
edge_properties = GremlinStorage._convert_properties(edge_data)
|
371 |
|
372 |
+
query = f"""g
|
373 |
+
.V().has('graph', {self.graph_name})
|
374 |
+
.has('entity_name', {source_node_name}).as('source')
|
375 |
+
.V().has('graph', {self.graph_name})
|
376 |
+
.has('entity_name', {target_node_name}).as('target')
|
|
|
377 |
.coalesce(
|
378 |
+
__.select('source').outE('DIRECTED').where(__.inV().as('target')),
|
379 |
+
__.select('source').addE('DIRECTED').to(__.select('target'))
|
380 |
+
)
|
381 |
+
.property('graph', {self.graph_name})
|
382 |
{edge_properties}
|
383 |
"""
|
384 |
try:
|
385 |
await self._query(query)
|
386 |
logger.debug(
|
387 |
+
"Upserted edge from {%s} to {%s} with properties: {%s}",
|
388 |
+
source_node_name,
|
389 |
+
target_node_name,
|
390 |
edge_properties,
|
391 |
)
|
392 |
except Exception as e:
|