alllexx88 commited on
Commit
d267c74
·
1 Parent(s): 2e57468

Add Gremlin graph storage

Browse files
examples/lightrag_ollama_gremlin_demo.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import inspect
3
+ import logging
4
+ import os
5
+
6
+ logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.WARN)
7
+
8
+ from lightrag import LightRAG, QueryParam
9
+ from lightrag.llm import ollama_embedding, ollama_model_complete
10
+ from lightrag.utils import EmbeddingFunc
11
+
12
+ WORKING_DIR = "./dickens_gremlin"
13
+
14
+ if not os.path.exists(WORKING_DIR):
15
+ os.mkdir(WORKING_DIR)
16
+
17
+ # Gremlin
18
+ os.environ["GREMLIN_HOST"] = "localhost"
19
+ os.environ["GREMLIN_PORT"] = "8182"
20
+ os.environ["GREMLIN_GRAPH"] = "dickens"
21
+
22
+ # Creating a non-default source requires manual
23
+ # configuration and a restart on the server: use the dafault "g"
24
+ os.environ["GREMLIN_TRAVERSE_SOURCE"] = "g"
25
+
26
+ # No authorization by default on docker tinkerpop/gremlin-server
27
+ os.environ["GREMLIN_USER"] = ""
28
+ os.environ["GREMLIN_PASSWORD"] = ""
29
+
30
+ rag = LightRAG(
31
+ working_dir=WORKING_DIR,
32
+ llm_model_func=ollama_model_complete,
33
+ llm_model_name="llama3.1:8b",
34
+ llm_model_max_async=4,
35
+ llm_model_max_token_size=32768,
36
+ llm_model_kwargs={"host": "http://localhost:11434", "options": {"num_ctx": 32768}},
37
+ embedding_func=EmbeddingFunc(
38
+ embedding_dim=768,
39
+ max_token_size=8192,
40
+ func=lambda texts: ollama_embedding(
41
+ texts, embed_model="nomic-embed-text", host="http://localhost:11434"
42
+ ),
43
+ ),
44
+ graph_storage="GremlinStorage",
45
+ )
46
+
47
+ with open("./book.txt", "r", encoding="utf-8") as f:
48
+ rag.insert(f.read())
49
+
50
+ # Perform naive search
51
+ print(
52
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
53
+ )
54
+
55
+ # Perform local search
56
+ print(
57
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
58
+ )
59
+
60
+ # Perform global search
61
+ print(
62
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
63
+ )
64
+
65
+ # Perform hybrid search
66
+ print(
67
+ rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
68
+ )
69
+
70
+ # stream response
71
+ resp = rag.query(
72
+ "What are the top themes in this story?",
73
+ param=QueryParam(mode="hybrid", stream=True),
74
+ )
75
+
76
+
77
+ async def print_stream(stream):
78
+ async for chunk in stream:
79
+ print(chunk, end="", flush=True)
80
+
81
+
82
+ if inspect.isasyncgen(resp):
83
+ asyncio.run(print_stream(resp))
84
+ else:
85
+ print(resp)
lightrag/kg/gremlin_impl.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import inspect
3
+ import json
4
+ import os
5
+ import re
6
+ from dataclasses import dataclass
7
+ from typing import Any, Dict, List, Tuple, Union
8
+
9
+ from gremlin_python.driver import client, serializer
10
+ from gremlin_python.driver.aiohttp.transport import AiohttpTransport
11
+ from gremlin_python.driver.protocol import GremlinServerError
12
+ from tenacity import (
13
+ retry,
14
+ retry_if_exception_type,
15
+ stop_after_attempt,
16
+ wait_exponential,
17
+ )
18
+
19
+ from lightrag.utils import logger
20
+
21
+ from ..base import BaseGraphStorage
22
+
23
+
24
+ @dataclass
25
+ class GremlinStorage(BaseGraphStorage):
26
+ @staticmethod
27
+ def load_nx_graph(file_name):
28
+ print("no preloading of graph with Gremlin in production")
29
+
30
+ # Will use this to make sure single quotes are properly escaped
31
+ escape_rx = re.compile(r"(^|[^\\])((\\\\)*\\)\\'")
32
+
33
+ def __init__(self, namespace, global_config, embedding_func):
34
+ super().__init__(
35
+ namespace=namespace,
36
+ global_config=global_config,
37
+ embedding_func=embedding_func,
38
+ )
39
+
40
+ self._driver = None
41
+ self._driver_lock = asyncio.Lock()
42
+
43
+ USER = os.environ.get("GREMLIN_USER", "")
44
+ PASSWORD = os.environ.get("GREMLIN_PASSWORD", "")
45
+ HOST = os.environ["GREMLIN_HOST"]
46
+ PORT = int(os.environ["GREMLIN_PORT"])
47
+
48
+ # TraversalSource, a custom one has to be created manually,
49
+ # default it "g"
50
+ SOURCE = os.environ.get("GREMLIN_TRAVERSE_SOURCE", "g")
51
+
52
+ # All vertices will have graph={GRAPH} property, so that we can
53
+ # have several logical graphs for one source
54
+ GRAPH = GremlinStorage.escape_rx.sub(
55
+ r"\1\2'",
56
+ os.environ["GREMLIN_GRAPH"].replace("'", r"\'"),
57
+ )
58
+
59
+ self.traverse_source_name = SOURCE
60
+ self.graph_name = GRAPH
61
+
62
+ self._driver = client.Client(
63
+ f"ws://{HOST}:{PORT}/gremlin",
64
+ SOURCE,
65
+ username=USER,
66
+ password=PASSWORD,
67
+ message_serializer=serializer.GraphSONSerializersV3d0(),
68
+ transport_factory=lambda: AiohttpTransport(call_from_event_loop=True),
69
+ )
70
+
71
+ def __post_init__(self):
72
+ self._node_embed_algorithms = {
73
+ "node2vec": self._node2vec_embed,
74
+ }
75
+
76
+ async def close(self):
77
+ if self._driver:
78
+ self._driver.close()
79
+ self._driver = None
80
+
81
+ async def __aexit__(self, exc_type, exc, tb):
82
+ if self._driver:
83
+ self._driver.close()
84
+
85
+ async def index_done_callback(self):
86
+ print("KG successfully indexed.")
87
+
88
+ @staticmethod
89
+ def _to_value_map(value: Any) -> str:
90
+ """Dump Python dict as Gremlin valueMap"""
91
+ json_str = json.dumps(value, ensure_ascii=False, sort_keys=False)
92
+ parsed_str = json_str.replace("'", r"\'")
93
+
94
+ # walk over the string and replace curly brackets with square brackets
95
+ # outside of strings, as well as replace double quotes with single quotes
96
+ # and "deescape" double quotes inside of strings
97
+ outside_str = True
98
+ escaped = False
99
+ remove_indices = []
100
+ for i, c in enumerate(parsed_str):
101
+ if escaped:
102
+ # previous character was an "odd" backslash
103
+ escaped = False
104
+ if c == '"':
105
+ # we want to "deescape" double quotes: store indices to delete
106
+ remove_indices.insert(0, i - 1)
107
+ elif c == "\\":
108
+ escaped = True
109
+ elif c == '"':
110
+ outside_str = not outside_str
111
+ parsed_str = parsed_str[:i] + "'" + parsed_str[i + 1 :]
112
+ elif c == "{" and outside_str:
113
+ parsed_str = parsed_str[:i] + "[" + parsed_str[i + 1 :]
114
+ elif c == "}" and outside_str:
115
+ parsed_str = parsed_str[:i] + "]" + parsed_str[i + 1 :]
116
+ for idx in remove_indices:
117
+ parsed_str = parsed_str[:idx] + parsed_str[idx + 1 :]
118
+ return parsed_str
119
+
120
+ @staticmethod
121
+ def _convert_properties(properties: Dict[str, Any]) -> str:
122
+ """Create chained .property() commands from properties dict"""
123
+ props = []
124
+ for k, v in properties.items():
125
+ prop_name = GremlinStorage.escape_rx.sub(r"\1\2'", k.replace("'", r"\'"))
126
+ props.append(f".property('{prop_name}', {GremlinStorage._to_value_map(v)})")
127
+ return "".join(props)
128
+
129
+ @staticmethod
130
+ def _fix_label(label: str) -> str:
131
+ """Strip double quotes and make sure single quotes are escaped"""
132
+ label = label.strip('"').replace("'", r"\'")
133
+ label = GremlinStorage.escape_rx.sub(r"\1\2'", label)
134
+
135
+ return label
136
+
137
+ async def _query(self, query: str) -> List[Dict[str, Any]]:
138
+ """
139
+ Query the Gremlin graph
140
+
141
+ Args:
142
+ query (str): a query to be executed
143
+
144
+ Returns:
145
+ List[Dict[str, Any]]: a list of dictionaries containing the result set
146
+ """
147
+
148
+ result = list(await asyncio.wrap_future(self._driver.submit_async(query)))
149
+
150
+ return result
151
+
152
+ async def has_node(self, node_id: str) -> bool:
153
+ entity_name_label = GremlinStorage._fix_label(node_id)
154
+
155
+ query = f"""
156
+ {self.traverse_source_name}
157
+ .V().has('graph', '{self.graph_name}')
158
+ .hasLabel('{entity_name_label}')
159
+ .limit(1)
160
+ .hasNext()
161
+ """
162
+ result = await self._query(query)
163
+ logger.debug(
164
+ "{%s}:query:{%s}:result:{%s}",
165
+ inspect.currentframe().f_code.co_name,
166
+ query,
167
+ result[0][0],
168
+ )
169
+
170
+ return result[0][0]
171
+
172
+ async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
173
+ entity_name_label_source = GremlinStorage._fix_label(source_node_id)
174
+ entity_name_label_target = GremlinStorage._fix_label(target_node_id)
175
+
176
+ query = f"""
177
+ {self.traverse_source_name}
178
+ .V().has('graph', '{self.graph_name}')
179
+ .hasLabel('{entity_name_label_source}')
180
+ .bothE()
181
+ .otherV().has('graph', '{self.graph_name}')
182
+ .hasLabel('{entity_name_label_target}')
183
+ .limit(1)
184
+ .hasNext()
185
+ """
186
+ result = await self._query(query)
187
+ logger.debug(
188
+ "{%s}:query:{%s}:result:{%s}",
189
+ inspect.currentframe().f_code.co_name,
190
+ query,
191
+ result[0][0],
192
+ )
193
+
194
+ return result[0][0]
195
+
196
+ async def get_node(self, node_id: str) -> Union[dict, None]:
197
+ entity_name_label = GremlinStorage._fix_label(node_id)
198
+ query = f"""
199
+ {self.traverse_source_name}
200
+ .V().has('graph', '{self.graph_name}')
201
+ .hasLabel('{entity_name_label}')
202
+ .limit(1)
203
+ .project('properties')
204
+ .by(elementMap())
205
+ """
206
+ result = await self._query(query)
207
+ if result:
208
+ node = result[0][0]
209
+ node_dict = node["properties"]
210
+ logger.debug(
211
+ "{%s}: query: {%s}, result: {%s}",
212
+ inspect.currentframe().f_code.co_name,
213
+ query.format,
214
+ node_dict,
215
+ )
216
+ return node_dict
217
+
218
+ async def node_degree(self, node_id: str) -> int:
219
+ entity_name_label = GremlinStorage._fix_label(node_id)
220
+ query = f"""
221
+ {self.traverse_source_name}
222
+ .V().has('graph', '{self.graph_name}')
223
+ .hasLabel('{entity_name_label}')
224
+ .outE()
225
+ .inV().has('graph', '{self.graph_name}')
226
+ .count()
227
+ .project('total_edge_count')
228
+ .by()
229
+ """
230
+ result = await self._query(query)
231
+ edge_count = result[0][0]["total_edge_count"]
232
+
233
+ logger.debug(
234
+ "{%s}:query:{%s}:result:{%s}",
235
+ inspect.currentframe().f_code.co_name,
236
+ query,
237
+ edge_count,
238
+ )
239
+
240
+ return edge_count
241
+
242
+ async def edge_degree(self, src_id: str, tgt_id: str) -> int:
243
+ src_degree = await self.node_degree(src_id)
244
+ trg_degree = await self.node_degree(tgt_id)
245
+
246
+ # Convert None to 0 for addition
247
+ src_degree = 0 if src_degree is None else src_degree
248
+ trg_degree = 0 if trg_degree is None else trg_degree
249
+
250
+ degrees = int(src_degree) + int(trg_degree)
251
+ logger.debug(
252
+ "{%s}:query:src_Degree+trg_degree:result:{%s}",
253
+ inspect.currentframe().f_code.co_name,
254
+ degrees,
255
+ )
256
+ return degrees
257
+
258
+ async def get_edge(
259
+ self, source_node_id: str, target_node_id: str
260
+ ) -> Union[dict, None]:
261
+ """
262
+ Find all edges between nodes of two given labels
263
+
264
+ Args:
265
+ source_node_label (str): Label of the source nodes
266
+ target_node_label (str): Label of the target nodes
267
+
268
+ Returns:
269
+ dict|None: Dict of found edge properties, or None of not found
270
+ """
271
+ entity_name_label_source = GremlinStorage._fix_label(source_node_id)
272
+ entity_name_label_target = GremlinStorage._fix_label(target_node_id)
273
+ query = f"""
274
+ {self.traverse_source_name}
275
+ .V().has('graph', '{self.graph_name}')
276
+ .hasLabel('{entity_name_label_source}')
277
+ .outE()
278
+ .inV().has('graph', '{self.graph_name}')
279
+ .hasLabel('{entity_name_label_target}')
280
+ .limit(1)
281
+ .project('edge_properties')
282
+ .by(__.bothE().elementMap())
283
+ """
284
+ result = await self._query(query)
285
+ if result:
286
+ edge_properties = result[0][0]["edge_properties"]
287
+ logger.debug(
288
+ "{%s}:query:{%s}:result:{%s}",
289
+ inspect.currentframe().f_code.co_name,
290
+ query,
291
+ edge_properties,
292
+ )
293
+ return edge_properties
294
+
295
+ async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
296
+ """
297
+ Retrieves all edges (relationships) for a particular node identified by its label.
298
+ :return: List of tuples containing edge sources and targets
299
+ """
300
+ node_label = GremlinStorage._fix_label(source_node_id)
301
+ query1 = f"""
302
+ {self.traverse_source_name}
303
+ .V().has('graph', '{self.graph_name}')
304
+ .hasLabel('{node_label}')
305
+ .out().has('graph', '{self.graph_name}')
306
+ .project('connected_label')
307
+ .by(__.label())
308
+ """
309
+ result1 = await self._query(query1)
310
+ edges1 = (
311
+ [(node_label, res["connected_label"]) for res in result1[0]]
312
+ if result1
313
+ else []
314
+ )
315
+
316
+ query2 = f"""
317
+ {self.traverse_source_name}
318
+ .V().has('graph', '{self.graph_name}')
319
+ .as('connected')
320
+ .out().has('graph', '{self.graph_name}')
321
+ .hasLabel('{node_label}')
322
+ .project('connected_label')
323
+ .by(__.select('connected').label())
324
+ """
325
+ result2 = await self._query(query2)
326
+ edges2 = (
327
+ [(res["connected_label"], node_label) for res in result2[0]]
328
+ if result2
329
+ else []
330
+ )
331
+
332
+ return edges1 + edges2
333
+
334
+ @retry(
335
+ stop=stop_after_attempt(3),
336
+ wait=wait_exponential(multiplier=1, min=4, max=10),
337
+ retry=retry_if_exception_type((GremlinServerError,)),
338
+ )
339
+ async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
340
+ """
341
+ Upsert a node in the Gremlin graph.
342
+
343
+ Args:
344
+ node_id: The unique identifier for the node (used as label)
345
+ node_data: Dictionary of node properties
346
+ """
347
+ label = GremlinStorage._fix_label(node_id)
348
+ properties = GremlinStorage._convert_properties(node_data)
349
+
350
+ query = f"""
351
+ {self.traverse_source_name}
352
+ .V().has('graph', '{self.graph_name}')
353
+ .hasLabel('{label}').fold()
354
+ .coalesce(
355
+ unfold(),
356
+ addV('{label}'))
357
+ .property('graph', '{self.graph_name}')
358
+ {properties}
359
+ """
360
+
361
+ try:
362
+ await self._query(query)
363
+ logger.debug(
364
+ "Upserted node with label '{%s}' and properties: {%s}",
365
+ label,
366
+ properties,
367
+ )
368
+ except Exception as e:
369
+ logger.error("Error during upsert: {%s}", e)
370
+ raise
371
+
372
+ @retry(
373
+ stop=stop_after_attempt(3),
374
+ wait=wait_exponential(multiplier=1, min=4, max=10),
375
+ retry=retry_if_exception_type((GremlinServerError,)),
376
+ )
377
+ async def upsert_edge(
378
+ self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
379
+ ):
380
+ """
381
+ Upsert an edge and its properties between two nodes identified by their labels.
382
+
383
+ Args:
384
+ source_node_id (str): Label of the source node (used as identifier)
385
+ target_node_id (str): Label of the target node (used as identifier)
386
+ edge_data (dict): Dictionary of properties to set on the edge
387
+ """
388
+ source_node_label = GremlinStorage._fix_label(source_node_id)
389
+ target_node_label = GremlinStorage._fix_label(target_node_id)
390
+ edge_properties = GremlinStorage._convert_properties(edge_data)
391
+
392
+ query = f"""
393
+ {self.traverse_source_name}
394
+ .V().has('graph', '{self.graph_name}')
395
+ .hasLabel('{source_node_label}').as('source')
396
+ .V().has('graph', '{self.graph_name}')
397
+ .hasLabel('{target_node_label}').as('target')
398
+ .coalesce(
399
+ select('source').outE('DIRECTED').where(inV().as('target')),
400
+ select('source').addE('DIRECTED').to(select('target'))
401
+ )
402
+ .property('graph', '{self.graph_name}')
403
+ {edge_properties}
404
+ """
405
+ try:
406
+ await self._query(query)
407
+ logger.debug(
408
+ "Upserted edge from '{%s}' to '{%s}' with properties: {%s}",
409
+ source_node_label,
410
+ target_node_label,
411
+ edge_properties,
412
+ )
413
+ except Exception as e:
414
+ logger.error("Error during edge upsert: {%s}", e)
415
+ raise
416
+
417
+ async def _node2vec_embed(self):
418
+ print("Implemented but never called.")
lightrag/lightrag.py CHANGED
@@ -81,6 +81,7 @@ TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage")
81
  TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage")
82
  TiDBGraphStorage = lazy_external_import(".kg.tidb_impl", "TiDBGraphStorage")
83
  AGEStorage = lazy_external_import(".kg.age_impl", "AGEStorage")
 
84
 
85
 
86
  def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
@@ -284,6 +285,7 @@ class LightRAG:
284
  "OracleGraphStorage": OracleGraphStorage,
285
  "AGEStorage": AGEStorage,
286
  "TiDBGraphStorage": TiDBGraphStorage,
 
287
  # "ArangoDBStorage": ArangoDBStorage
288
  }
289
 
 
81
  TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage")
82
  TiDBGraphStorage = lazy_external_import(".kg.tidb_impl", "TiDBGraphStorage")
83
  AGEStorage = lazy_external_import(".kg.age_impl", "AGEStorage")
84
+ GremlinStorage = lazy_external_import(".kg.gremlin_impl", "GremlinStorage")
85
 
86
 
87
  def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
 
285
  "OracleGraphStorage": OracleGraphStorage,
286
  "AGEStorage": AGEStorage,
287
  "TiDBGraphStorage": TiDBGraphStorage,
288
+ "GremlinStorage": GremlinStorage,
289
  # "ArangoDBStorage": ArangoDBStorage
290
  }
291
 
requirements.txt CHANGED
@@ -4,6 +4,7 @@ aiohttp
4
 
5
  # database packages
6
  graspologic
 
7
  hnswlib
8
  nano-vectordb
9
  neo4j
 
4
 
5
  # database packages
6
  graspologic
7
+ gremlinpython
8
  hnswlib
9
  nano-vectordb
10
  neo4j