zrguo commited on
Commit
90740b5
·
unverified ·
2 Parent(s): d9fd40b c57c526

Merge pull request #505 from alllexx88/arcadedb_gremlin_graph

Browse files
Files changed (1) hide show
  1. lightrag/kg/gremlin_impl.py +106 -127
lightrag/kg/gremlin_impl.py CHANGED
@@ -2,7 +2,6 @@ import asyncio
2
  import inspect
3
  import json
4
  import os
5
- import re
6
  from dataclasses import dataclass
7
  from typing import Any, Dict, List, Tuple, Union
8
 
@@ -27,9 +26,6 @@ class GremlinStorage(BaseGraphStorage):
27
  def load_nx_graph(file_name):
28
  print("no preloading of graph with Gremlin in production")
29
 
30
- # Will use this to make sure single quotes are properly escaped
31
- escape_rx = re.compile(r"(^|[^\\])((\\\\)*\\)\\'")
32
-
33
  def __init__(self, namespace, global_config, embedding_func):
34
  super().__init__(
35
  namespace=namespace,
@@ -51,12 +47,8 @@ class GremlinStorage(BaseGraphStorage):
51
 
52
  # All vertices will have graph={GRAPH} property, so that we can
53
  # have several logical graphs for one source
54
- GRAPH = GremlinStorage.escape_rx.sub(
55
- r"\1\2'",
56
- os.environ["GREMLIN_GRAPH"].replace("'", r"\'"),
57
- )
58
 
59
- self.traverse_source_name = SOURCE
60
  self.graph_name = GRAPH
61
 
62
  self._driver = client.Client(
@@ -87,7 +79,7 @@ class GremlinStorage(BaseGraphStorage):
87
 
88
  @staticmethod
89
  def _to_value_map(value: Any) -> str:
90
- """Dump Python dict as Gremlin valueMap"""
91
  json_str = json.dumps(value, ensure_ascii=False, sort_keys=False)
92
  parsed_str = json_str.replace("'", r"\'")
93
 
@@ -122,17 +114,16 @@ class GremlinStorage(BaseGraphStorage):
122
  """Create chained .property() commands from properties dict"""
123
  props = []
124
  for k, v in properties.items():
125
- prop_name = GremlinStorage.escape_rx.sub(r"\1\2'", k.replace("'", r"\'"))
126
- props.append(f".property('{prop_name}', {GremlinStorage._to_value_map(v)})")
127
  return "".join(props)
128
 
129
  @staticmethod
130
- def _fix_label(label: str) -> str:
131
- """Strip double quotes and make sure single quotes are escaped"""
132
- label = label.strip('"').replace("'", r"\'")
133
- label = GremlinStorage.escape_rx.sub(r"\1\2'", label)
134
 
135
- return label
136
 
137
  async def _query(self, query: str) -> List[Dict[str, Any]]:
138
  """
@@ -146,66 +137,69 @@ class GremlinStorage(BaseGraphStorage):
146
  """
147
 
148
  result = list(await asyncio.wrap_future(self._driver.submit_async(query)))
 
 
149
 
150
  return result
151
 
152
  async def has_node(self, node_id: str) -> bool:
153
- entity_name_label = GremlinStorage._fix_label(node_id)
154
 
155
- query = f"""
156
- {self.traverse_source_name}
157
- .V().has('graph', '{self.graph_name}')
158
- .hasLabel('{entity_name_label}')
159
  .limit(1)
160
- .hasNext()
 
 
161
  """
162
  result = await self._query(query)
163
  logger.debug(
164
  "{%s}:query:{%s}:result:{%s}",
165
  inspect.currentframe().f_code.co_name,
166
  query,
167
- result[0][0],
168
  )
169
 
170
- return result[0][0]
171
 
172
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
173
- entity_name_label_source = GremlinStorage._fix_label(source_node_id)
174
- entity_name_label_target = GremlinStorage._fix_label(target_node_id)
175
-
176
- query = f"""
177
- {self.traverse_source_name}
178
- .V().has('graph', '{self.graph_name}')
179
- .hasLabel('{entity_name_label_source}')
180
- .bothE()
181
- .otherV().has('graph', '{self.graph_name}')
182
- .hasLabel('{entity_name_label_target}')
183
  .limit(1)
184
- .hasNext()
 
 
185
  """
186
  result = await self._query(query)
187
  logger.debug(
188
  "{%s}:query:{%s}:result:{%s}",
189
  inspect.currentframe().f_code.co_name,
190
  query,
191
- result[0][0],
192
  )
193
 
194
- return result[0][0]
195
 
196
  async def get_node(self, node_id: str) -> Union[dict, None]:
197
- entity_name_label = GremlinStorage._fix_label(node_id)
198
- query = f"""
199
- {self.traverse_source_name}
200
- .V().has('graph', '{self.graph_name}')
201
- .hasLabel('{entity_name_label}')
202
  .limit(1)
203
  .project('properties')
204
  .by(elementMap())
205
  """
206
  result = await self._query(query)
207
  if result:
208
- node = result[0][0]
209
  node_dict = node["properties"]
210
  logger.debug(
211
  "{%s}: query: {%s}, result: {%s}",
@@ -216,19 +210,18 @@ class GremlinStorage(BaseGraphStorage):
216
  return node_dict
217
 
218
  async def node_degree(self, node_id: str) -> int:
219
- entity_name_label = GremlinStorage._fix_label(node_id)
220
- query = f"""
221
- {self.traverse_source_name}
222
- .V().has('graph', '{self.graph_name}')
223
- .hasLabel('{entity_name_label}')
224
  .outE()
225
- .inV().has('graph', '{self.graph_name}')
226
  .count()
227
  .project('total_edge_count')
228
  .by()
229
  """
230
  result = await self._query(query)
231
- edge_count = result[0][0]["total_edge_count"]
232
 
233
  logger.debug(
234
  "{%s}:query:{%s}:result:{%s}",
@@ -259,31 +252,30 @@ class GremlinStorage(BaseGraphStorage):
259
  self, source_node_id: str, target_node_id: str
260
  ) -> Union[dict, None]:
261
  """
262
- Find all edges between nodes of two given labels
263
 
264
  Args:
265
- source_node_label (str): Label of the source nodes
266
- target_node_label (str): Label of the target nodes
267
 
268
  Returns:
269
- dict|None: Dict of found edge properties, or None of not found
270
  """
271
- entity_name_label_source = GremlinStorage._fix_label(source_node_id)
272
- entity_name_label_target = GremlinStorage._fix_label(target_node_id)
273
- query = f"""
274
- {self.traverse_source_name}
275
- .V().has('graph', '{self.graph_name}')
276
- .hasLabel('{entity_name_label_source}')
277
  .outE()
278
- .inV().has('graph', '{self.graph_name}')
279
- .hasLabel('{entity_name_label_target}')
280
  .limit(1)
281
  .project('edge_properties')
282
  .by(__.bothE().elementMap())
283
  """
284
  result = await self._query(query)
285
  if result:
286
- edge_properties = result[0][0]["edge_properties"]
287
  logger.debug(
288
  "{%s}:query:{%s}:result:{%s}",
289
  inspect.currentframe().f_code.co_name,
@@ -294,45 +286,31 @@ class GremlinStorage(BaseGraphStorage):
294
 
295
  async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
296
  """
297
- Retrieves all edges (relationships) for a particular node identified by its label.
298
  :return: List of tuples containing edge sources and targets
299
  """
300
- node_label = GremlinStorage._fix_label(source_node_id)
301
- query1 = f"""
302
- {self.traverse_source_name}
303
- .V().has('graph', '{self.graph_name}')
304
- .hasLabel('{node_label}')
305
- .out().has('graph', '{self.graph_name}')
306
- .project('connected_label')
307
- .by(__.label())
308
- """
309
- query2 = f"""
310
- {self.traverse_source_name}
311
- .V().has('graph', '{self.graph_name}')
312
- .as('connected')
313
- .out().has('graph', '{self.graph_name}')
314
- .hasLabel('{node_label}')
315
- .project('connected_label')
316
- .by(__.select('connected').label())
317
  """
318
- result1, result2 = await asyncio.gather(
319
- self._query(query1), self._query(query2)
320
- )
321
- edges1 = (
322
- [(node_label, res["connected_label"]) for res in result1[0]]
323
- if result1
324
- else []
325
- )
326
- edges2 = (
327
- [(res["connected_label"], node_label) for res in result2[0]]
328
- if result2
329
- else []
330
- )
331
 
332
- return edges1 + edges2
333
 
334
  @retry(
335
- stop=stop_after_attempt(3),
336
  wait=wait_exponential(multiplier=1, min=4, max=10),
337
  retry=retry_if_exception_type((GremlinServerError,)),
338
  )
@@ -341,28 +319,30 @@ class GremlinStorage(BaseGraphStorage):
341
  Upsert a node in the Gremlin graph.
342
 
343
  Args:
344
- node_id: The unique identifier for the node (used as label)
345
  node_data: Dictionary of node properties
346
  """
347
- label = GremlinStorage._fix_label(node_id)
348
  properties = GremlinStorage._convert_properties(node_data)
349
 
350
- query = f"""
351
- {self.traverse_source_name}
352
- .V().has('graph', '{self.graph_name}')
353
- .hasLabel('{label}').fold()
354
  .coalesce(
355
- unfold(),
356
- addV('{label}'))
357
- .property('graph', '{self.graph_name}')
 
 
358
  {properties}
359
  """
360
 
361
  try:
362
  await self._query(query)
363
  logger.debug(
364
- "Upserted node with label '{%s}' and properties: {%s}",
365
- label,
366
  properties,
367
  )
368
  except Exception as e:
@@ -370,7 +350,7 @@ class GremlinStorage(BaseGraphStorage):
370
  raise
371
 
372
  @retry(
373
- stop=stop_after_attempt(3),
374
  wait=wait_exponential(multiplier=1, min=4, max=10),
375
  retry=retry_if_exception_type((GremlinServerError,)),
376
  )
@@ -378,36 +358,35 @@ class GremlinStorage(BaseGraphStorage):
378
  self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
379
  ):
380
  """
381
- Upsert an edge and its properties between two nodes identified by their labels.
382
 
383
  Args:
384
- source_node_id (str): Label of the source node (used as identifier)
385
- target_node_id (str): Label of the target node (used as identifier)
386
  edge_data (dict): Dictionary of properties to set on the edge
387
  """
388
- source_node_label = GremlinStorage._fix_label(source_node_id)
389
- target_node_label = GremlinStorage._fix_label(target_node_id)
390
  edge_properties = GremlinStorage._convert_properties(edge_data)
391
 
392
- query = f"""
393
- {self.traverse_source_name}
394
- .V().has('graph', '{self.graph_name}')
395
- .hasLabel('{source_node_label}').as('source')
396
- .V().has('graph', '{self.graph_name}')
397
- .hasLabel('{target_node_label}').as('target')
398
  .coalesce(
399
- select('source').outE('DIRECTED').where(inV().as('target')),
400
- select('source').addE('DIRECTED').to(select('target'))
401
- )
402
- .property('graph', '{self.graph_name}')
403
  {edge_properties}
404
  """
405
  try:
406
  await self._query(query)
407
  logger.debug(
408
- "Upserted edge from '{%s}' to '{%s}' with properties: {%s}",
409
- source_node_label,
410
- target_node_label,
411
  edge_properties,
412
  )
413
  except Exception as e:
 
2
  import inspect
3
  import json
4
  import os
 
5
  from dataclasses import dataclass
6
  from typing import Any, Dict, List, Tuple, Union
7
 
 
26
  def load_nx_graph(file_name):
27
  print("no preloading of graph with Gremlin in production")
28
 
 
 
 
29
  def __init__(self, namespace, global_config, embedding_func):
30
  super().__init__(
31
  namespace=namespace,
 
47
 
48
  # All vertices will have graph={GRAPH} property, so that we can
49
  # have several logical graphs for one source
50
+ GRAPH = GremlinStorage._to_value_map(os.environ["GREMLIN_GRAPH"])
 
 
 
51
 
 
52
  self.graph_name = GRAPH
53
 
54
  self._driver = client.Client(
 
79
 
80
  @staticmethod
81
  def _to_value_map(value: Any) -> str:
82
+ """Dump supported Python object as Gremlin valueMap"""
83
  json_str = json.dumps(value, ensure_ascii=False, sort_keys=False)
84
  parsed_str = json_str.replace("'", r"\'")
85
 
 
114
  """Create chained .property() commands from properties dict"""
115
  props = []
116
  for k, v in properties.items():
117
+ prop_name = GremlinStorage._to_value_map(k)
118
+ props.append(f".property({prop_name}, {GremlinStorage._to_value_map(v)})")
119
  return "".join(props)
120
 
121
  @staticmethod
122
+ def _fix_name(name: str) -> str:
123
+ """Strip double quotes and format as a proper field name"""
124
+ name = GremlinStorage._to_value_map(name.strip('"').replace(r"\'", "'"))
 
125
 
126
+ return name
127
 
128
  async def _query(self, query: str) -> List[Dict[str, Any]]:
129
  """
 
137
  """
138
 
139
  result = list(await asyncio.wrap_future(self._driver.submit_async(query)))
140
+ if result:
141
+ result = result[0]
142
 
143
  return result
144
 
145
  async def has_node(self, node_id: str) -> bool:
146
+ entity_name = GremlinStorage._fix_name(node_id)
147
 
148
+ query = f"""g
149
+ .V().has('graph', {self.graph_name})
150
+ .has('entity_name', {entity_name})
 
151
  .limit(1)
152
+ .count()
153
+ .project('has_node')
154
+ .by(__.choose(__.is(gt(0)), constant(true), constant(false)))
155
  """
156
  result = await self._query(query)
157
  logger.debug(
158
  "{%s}:query:{%s}:result:{%s}",
159
  inspect.currentframe().f_code.co_name,
160
  query,
161
+ result[0]["has_node"],
162
  )
163
 
164
+ return result[0]["has_node"]
165
 
166
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
167
+ entity_name_source = GremlinStorage._fix_name(source_node_id)
168
+ entity_name_target = GremlinStorage._fix_name(target_node_id)
169
+
170
+ query = f"""g
171
+ .V().has('graph', {self.graph_name})
172
+ .has('entity_name', {entity_name_source})
173
+ .outE()
174
+ .inV().has('graph', {self.graph_name})
175
+ .has('entity_name', {entity_name_target})
 
176
  .limit(1)
177
+ .count()
178
+ .project('has_edge')
179
+ .by(__.choose(__.is(gt(0)), constant(true), constant(false)))
180
  """
181
  result = await self._query(query)
182
  logger.debug(
183
  "{%s}:query:{%s}:result:{%s}",
184
  inspect.currentframe().f_code.co_name,
185
  query,
186
+ result[0]["has_edge"],
187
  )
188
 
189
+ return result[0]["has_edge"]
190
 
191
  async def get_node(self, node_id: str) -> Union[dict, None]:
192
+ entity_name = GremlinStorage._fix_name(node_id)
193
+ query = f"""g
194
+ .V().has('graph', {self.graph_name})
195
+ .has('entity_name', {entity_name})
 
196
  .limit(1)
197
  .project('properties')
198
  .by(elementMap())
199
  """
200
  result = await self._query(query)
201
  if result:
202
+ node = result[0]
203
  node_dict = node["properties"]
204
  logger.debug(
205
  "{%s}: query: {%s}, result: {%s}",
 
210
  return node_dict
211
 
212
  async def node_degree(self, node_id: str) -> int:
213
+ entity_name = GremlinStorage._fix_name(node_id)
214
+ query = f"""g
215
+ .V().has('graph', {self.graph_name})
216
+ .has('entity_name', {entity_name})
 
217
  .outE()
218
+ .inV().has('graph', {self.graph_name})
219
  .count()
220
  .project('total_edge_count')
221
  .by()
222
  """
223
  result = await self._query(query)
224
+ edge_count = result[0]["total_edge_count"]
225
 
226
  logger.debug(
227
  "{%s}:query:{%s}:result:{%s}",
 
252
  self, source_node_id: str, target_node_id: str
253
  ) -> Union[dict, None]:
254
  """
255
+ Find all edges between nodes of two given names
256
 
257
  Args:
258
+ source_node_id (str): Name of the source nodes
259
+ target_node_id (str): Name of the target nodes
260
 
261
  Returns:
262
+ dict|None: Dict of found edge properties, or None if not found
263
  """
264
+ entity_name_source = GremlinStorage._fix_name(source_node_id)
265
+ entity_name_target = GremlinStorage._fix_name(target_node_id)
266
+ query = f"""g
267
+ .V().has('graph', {self.graph_name})
268
+ .has('entity_name', {entity_name_source})
 
269
  .outE()
270
+ .inV().has('graph', {self.graph_name})
271
+ .has('entity_name', {entity_name_target})
272
  .limit(1)
273
  .project('edge_properties')
274
  .by(__.bothE().elementMap())
275
  """
276
  result = await self._query(query)
277
  if result:
278
+ edge_properties = result[0]["edge_properties"]
279
  logger.debug(
280
  "{%s}:query:{%s}:result:{%s}",
281
  inspect.currentframe().f_code.co_name,
 
286
 
287
  async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
288
  """
289
+ Retrieves all edges (relationships) for a particular node identified by its name.
290
  :return: List of tuples containing edge sources and targets
291
  """
292
+ node_name = GremlinStorage._fix_name(source_node_id)
293
+ query = f"""g
294
+ .E()
295
+ .filter(
296
+ __.or(
297
+ __.outV().has('graph', {self.graph_name})
298
+ .has('entity_name', {node_name}),
299
+ __.inV().has('graph', {self.graph_name})
300
+ .has('entity_name', {node_name})
301
+ )
302
+ )
303
+ .project('source_name', 'target_name')
304
+ .by(__.outV().values('entity_name'))
305
+ .by(__.inV().values('entity_name'))
 
 
 
306
  """
307
+ result = await self._query(query)
308
+ edges = [(res["source_name"], res["target_name"]) for res in result]
 
 
 
 
 
 
 
 
 
 
 
309
 
310
+ return edges
311
 
312
  @retry(
313
+ stop=stop_after_attempt(10),
314
  wait=wait_exponential(multiplier=1, min=4, max=10),
315
  retry=retry_if_exception_type((GremlinServerError,)),
316
  )
 
319
  Upsert a node in the Gremlin graph.
320
 
321
  Args:
322
+ node_id: The unique identifier for the node (used as name)
323
  node_data: Dictionary of node properties
324
  """
325
+ name = GremlinStorage._fix_name(node_id)
326
  properties = GremlinStorage._convert_properties(node_data)
327
 
328
+ query = f"""g
329
+ .V().has('graph', {self.graph_name})
330
+ .has('entity_name', {name})
331
+ .fold()
332
  .coalesce(
333
+ __.unfold(),
334
+ __.addV('ENTITY')
335
+ .property('graph', {self.graph_name})
336
+ .property('entity_name', {name})
337
+ )
338
  {properties}
339
  """
340
 
341
  try:
342
  await self._query(query)
343
  logger.debug(
344
+ "Upserted node with name {%s} and properties: {%s}",
345
+ name,
346
  properties,
347
  )
348
  except Exception as e:
 
350
  raise
351
 
352
  @retry(
353
+ stop=stop_after_attempt(10),
354
  wait=wait_exponential(multiplier=1, min=4, max=10),
355
  retry=retry_if_exception_type((GremlinServerError,)),
356
  )
 
358
  self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
359
  ):
360
  """
361
+ Upsert an edge and its properties between two nodes identified by their names.
362
 
363
  Args:
364
+ source_node_id (str): Name of the source node (used as identifier)
365
+ target_node_id (str): Name of the target node (used as identifier)
366
  edge_data (dict): Dictionary of properties to set on the edge
367
  """
368
+ source_node_name = GremlinStorage._fix_name(source_node_id)
369
+ target_node_name = GremlinStorage._fix_name(target_node_id)
370
  edge_properties = GremlinStorage._convert_properties(edge_data)
371
 
372
+ query = f"""g
373
+ .V().has('graph', {self.graph_name})
374
+ .has('entity_name', {source_node_name}).as('source')
375
+ .V().has('graph', {self.graph_name})
376
+ .has('entity_name', {target_node_name}).as('target')
 
377
  .coalesce(
378
+ __.select('source').outE('DIRECTED').where(__.inV().as('target')),
379
+ __.select('source').addE('DIRECTED').to(__.select('target'))
380
+ )
381
+ .property('graph', {self.graph_name})
382
  {edge_properties}
383
  """
384
  try:
385
  await self._query(query)
386
  logger.debug(
387
+ "Upserted edge from {%s} to {%s} with properties: {%s}",
388
+ source_node_name,
389
+ target_node_name,
390
  edge_properties,
391
  )
392
  except Exception as e: