alllexx88 committed on
Commit
1ee03f4
·
1 Parent(s): a928fde

Add Apache AGE graph storage

Browse files
examples/lightrag_ollama_age_demo.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import inspect
3
+ import logging
4
+ import os
5
+
6
+ from lightrag import LightRAG, QueryParam
7
+ from lightrag.llm import ollama_embedding, ollama_model_complete
8
+ from lightrag.utils import EmbeddingFunc
9
+
10
# Demo: index a book with LightRAG using Ollama models and Apache AGE as the
# graph backend, then run the same question through every query mode.
WORKING_DIR = "./dickens_age"

logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)

# AGE: connection settings read by the AGEStorage backend from the environment
os.environ.update(
    {
        "AGE_POSTGRES_DB": "postgresDB",
        "AGE_POSTGRES_USER": "postgresUser",
        "AGE_POSTGRES_PASSWORD": "postgresPW",
        "AGE_POSTGRES_HOST": "localhost",
        "AGE_POSTGRES_PORT": "5455",
        "AGE_GRAPH_NAME": "dickens",
    }
)


def _embed(texts):
    """Embed ``texts`` with the local Ollama embedding model."""
    return ollama_embedding(
        texts, embed_model="nomic-embed-text", host="http://localhost:11434"
    )


async def print_stream(stream):
    """Print every chunk of an async response stream as it arrives."""
    async for chunk in stream:
        print(chunk, end="", flush=True)


rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=ollama_model_complete,
    llm_model_name="llama3.1:8b",
    llm_model_max_async=4,
    llm_model_max_token_size=32768,
    llm_model_kwargs={"host": "http://localhost:11434", "options": {"num_ctx": 32768}},
    embedding_func=EmbeddingFunc(
        embedding_dim=768,
        max_token_size=8192,
        func=_embed,
    ),
    graph_storage="AGEStorage",
)

with open("./book.txt", "r", encoding="utf-8") as f:
    rag.insert(f.read())

# Run the question once per retrieval mode: naive, local, global, hybrid
for search_mode in ("naive", "local", "global", "hybrid"):
    print(
        rag.query(
            "What are the top themes in this story?",
            param=QueryParam(mode=search_mode),
        )
    )

# stream response
resp = rag.query(
    "What are the top themes in this story?",
    param=QueryParam(mode="hybrid", stream=True),
)

# The result is only consumed as a stream when it is an async generator
if not inspect.isasyncgen(resp):
    print(resp)
else:
    asyncio.run(print_stream(resp))
lightrag/kg/age_impl.py ADDED
@@ -0,0 +1,613 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import inspect
3
+ import json
4
+ import os
5
+ from contextlib import asynccontextmanager
6
+ from dataclasses import dataclass
7
+ from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union
8
+
9
+ import psycopg
10
+ from psycopg.rows import namedtuple_row
11
+ from psycopg_pool import AsyncConnectionPool, PoolTimeout
12
+ from tenacity import (
13
+ retry,
14
+ retry_if_exception_type,
15
+ stop_after_attempt,
16
+ wait_exponential,
17
+ )
18
+
19
+ from lightrag.utils import logger
20
+
21
+ from ..base import BaseGraphStorage
22
+
23
+
24
class AGEQueryException(Exception):
    """Exception for the AGE queries.

    Built either from a plain message string or from a dictionary holding a
    ``"message"`` entry and a ``"details"`` (or ``"detail"``) entry.
    """

    def __init__(self, exception: Union[str, Dict]) -> None:
        """
        Args:
            exception (Union[str, Dict]): the error message, or a dictionary
                with the message and the diagnostic details
        """
        if isinstance(exception, dict):
            self.message = exception.get("message", "unknown")
            # Accept the singular "detail" spelling too, so a payload built
            # with either key keeps its diagnostic text instead of "unknown".
            if "details" in exception:
                self.details = exception["details"]
            else:
                self.details = exception.get("detail", "unknown")
        else:
            self.message = exception
            self.details = "unknown"

    def get_message(self) -> str:
        """Return the error message."""
        return self.message

    def get_details(self) -> Any:
        """Return the diagnostic details, or ``"unknown"`` if none were given."""
        return self.details
40
+
41
+
42
+ @dataclass
43
+ class AGEStorage(BaseGraphStorage):
44
+ @staticmethod
45
+ def load_nx_graph(file_name):
46
+ print("no preloading of graph with AGE in production")
47
+
48
def __init__(self, namespace, global_config, embedding_func):
    """Initialise the storage and prepare an unopened connection pool.

    Connection settings come from the ``AGE_POSTGRES_DB``,
    ``AGE_POSTGRES_USER``, ``AGE_POSTGRES_PASSWORD``, ``AGE_POSTGRES_HOST``,
    ``AGE_POSTGRES_PORT`` and ``AGE_GRAPH_NAME`` environment variables.

    Raises:
        KeyError: if one of the environment variables is missing
        ValueError: if ``AGE_POSTGRES_PORT`` is not an integer
    """
    super().__init__(
        namespace=namespace,
        global_config=global_config,
        embedding_func=embedding_func,
    )
    self._driver = None
    self._driver_lock = asyncio.Lock()

    def _escaped_env(variable):
        # Escape backslashes first, then single quotes, so the value can sit
        # inside a single-quoted field of the connection string.
        raw_value = os.environ[variable]
        return raw_value.replace("\\", "\\\\").replace("'", "\\'")

    db_name = _escaped_env("AGE_POSTGRES_DB")
    user_name = _escaped_env("AGE_POSTGRES_USER")
    password = _escaped_env("AGE_POSTGRES_PASSWORD")
    host_name = _escaped_env("AGE_POSTGRES_HOST")
    port_number = int(os.environ["AGE_POSTGRES_PORT"])
    self.graph_name = os.environ["AGE_GRAPH_NAME"]

    connection_string = (
        f"dbname='{db_name}' user='{user_name}' password='{password}' "
        f"host='{host_name}' port={port_number}"
    )

    # open=False: the pool is opened later, when a query runs
    self._driver = AsyncConnectionPool(connection_string, open=False)
72
+
73
+ def __post_init__(self):
74
+ self._node_embed_algorithms = {
75
+ "node2vec": self._node2vec_embed,
76
+ }
77
+
78
+ async def close(self):
79
+ if self._driver:
80
+ await self._driver.close()
81
+ self._driver = None
82
+
83
+ async def __aexit__(self, exc_type, exc, tb):
84
+ if self._driver:
85
+ await self._driver.close()
86
+
87
+ async def index_done_callback(self):
88
+ print("KG successfully indexed.")
89
+
90
+ @staticmethod
91
+ def _record_to_dict(record: NamedTuple) -> Dict[str, Any]:
92
+ """
93
+ Convert a record returned from an age query to a dictionary
94
+
95
+ Args:
96
+ record (): a record from an age query result
97
+
98
+ Returns:
99
+ Dict[str, Any]: a dictionary representation of the record where
100
+ the dictionary key is the field name and the value is the
101
+ value converted to a python type
102
+ """
103
+ # result holder
104
+ d = {}
105
+
106
+ # prebuild a mapping of vertex_id to vertex mappings to be used
107
+ # later to build edges
108
+ vertices = {}
109
+ for k in record._fields:
110
+ v = getattr(record, k)
111
+ # agtype comes back '{key: value}::type' which must be parsed
112
+ if isinstance(v, str) and "::" in v:
113
+ dtype = v.split("::")[-1]
114
+ v = v.split("::")[0]
115
+ if dtype == "vertex":
116
+ vertex = json.loads(v)
117
+ vertices[vertex["id"]] = vertex.get("properties")
118
+
119
+ # iterate returned fields and parse appropriately
120
+ for k in record._fields:
121
+ v = getattr(record, k)
122
+ if isinstance(v, str) and "::" in v:
123
+ dtype = v.split("::")[-1]
124
+ v = v.split("::")[0]
125
+ else:
126
+ dtype = ""
127
+
128
+ if dtype == "vertex":
129
+ vertex = json.loads(v)
130
+ field = json.loads(v).get("properties")
131
+ if not field:
132
+ field = {}
133
+ field["label"] = AGEStorage._decode_graph_label(vertex["label"])
134
+ d[k] = field
135
+ # convert edge from id-label->id by replacing id with node information
136
+ # we only do this if the vertex was also returned in the query
137
+ # this is an attempt to be consistent with neo4j implementation
138
+ elif dtype == "edge":
139
+ edge = json.loads(v)
140
+ d[k] = (
141
+ vertices.get(edge["start_id"], {}),
142
+ edge[
143
+ "label"
144
+ ], # we don't use decode_graph_label(), since edge label is always "DIRECTED"
145
+ vertices.get(edge["end_id"], {}),
146
+ )
147
+ else:
148
+ d[k] = json.loads(v) if isinstance(v, str) else v
149
+
150
+ return d
151
+
152
+ @staticmethod
153
+ def _format_properties(
154
+ properties: Dict[str, Any], _id: Union[str, None] = None
155
+ ) -> str:
156
+ """
157
+ Convert a dictionary of properties to a string representation that
158
+ can be used in a cypher query insert/merge statement.
159
+
160
+ Args:
161
+ properties (Dict[str,str]): a dictionary containing node/edge properties
162
+ id (Union[str, None]): the id of the node or None if none exists
163
+
164
+ Returns:
165
+ str: the properties dictionary as a properly formatted string
166
+ """
167
+ props = []
168
+ # wrap property key in backticks to escape
169
+ for k, v in properties.items():
170
+ prop = f"`{k}`: {json.dumps(v)}"
171
+ props.append(prop)
172
+ if _id is not None and "id" not in properties:
173
+ props.append(
174
+ f"id: {json.dumps(_id)}" if isinstance(_id, str) else f"id: {_id}"
175
+ )
176
+ return "{" + ", ".join(props) + "}"
177
+
178
+ @staticmethod
179
+ def _encode_graph_label(label: str) -> str:
180
+ """
181
+ Since AGE suports only alphanumerical labels, we will encode generic label as HEX string
182
+
183
+ Args:
184
+ label (str): the original label
185
+
186
+ Returns:
187
+ str: the encoded label
188
+ """
189
+ return "x" + label.encode().hex()
190
+
191
+ @staticmethod
192
+ def _decode_graph_label(encoded_label: str) -> str:
193
+ """
194
+ Since AGE suports only alphanumerical labels, we will encode generic label as HEX string
195
+
196
+ Args:
197
+ encoded_label (str): the encoded label
198
+
199
+ Returns:
200
+ str: the decoded label
201
+ """
202
+ return bytes.fromhex(encoded_label.removeprefix("x")).decode()
203
+
204
+ @staticmethod
205
+ def _get_col_name(field: str, idx: int) -> str:
206
+ """
207
+ Convert a cypher return field to a pgsql select field
208
+ If possible keep the cypher column name, but create a generic name if necessary
209
+
210
+ Args:
211
+ field (str): a return field from a cypher query to be formatted for pgsql
212
+ idx (int): the position of the field in the return statement
213
+
214
+ Returns:
215
+ str: the field to be used in the pgsql select statement
216
+ """
217
+ # remove white space
218
+ field = field.strip()
219
+ # if an alias is provided for the field, use it
220
+ if " as " in field:
221
+ return field.split(" as ")[-1].strip()
222
+ # if the return value is an unnamed primitive, give it a generic name
223
+ if field.isnumeric() or field in ("true", "false", "null"):
224
+ return f"column_{idx}"
225
+ # otherwise return the value stripping out some common special chars
226
+ return field.replace("(", "_").replace(")", "")
227
+
228
+ @staticmethod
229
+ def _wrap_query(query: str, graph_name: str, **params: str) -> str:
230
+ """
231
+ Convert a cypher query to an Apache Age compatible
232
+ sql query by wrapping the cypher query in ag_catalog.cypher,
233
+ casting results to agtype and building a select statement
234
+
235
+ Args:
236
+ query (str): a valid cypher query
237
+ graph_name (str): the name of the graph to query
238
+ params (dict): parameters for the query
239
+
240
+ Returns:
241
+ str: an equivalent pgsql query
242
+ """
243
+
244
+ # pgsql template
245
+ template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$
246
+ {query}
247
+ $$) AS ({fields});"""
248
+
249
+ # if there are any returned fields they must be added to the pgsql query
250
+ if "return" in query.lower():
251
+ # parse return statement to identify returned fields
252
+ fields = (
253
+ query.lower()
254
+ .split("return")[-1]
255
+ .split("distinct")[-1]
256
+ .split("order by")[0]
257
+ .split("skip")[0]
258
+ .split("limit")[0]
259
+ .split(",")
260
+ )
261
+
262
+ # raise exception if RETURN * is found as we can't resolve the fields
263
+ if "*" in [x.strip() for x in fields]:
264
+ raise ValueError(
265
+ "AGE graph does not support 'RETURN *'"
266
+ + " statements in Cypher queries"
267
+ )
268
+
269
+ # get pgsql formatted field names
270
+ fields = [
271
+ AGEStorage._get_col_name(field, idx) for idx, field in enumerate(fields)
272
+ ]
273
+
274
+ # build resulting pgsql relation
275
+ fields_str = ", ".join(
276
+ [field.split(".")[-1] + " agtype" for field in fields]
277
+ )
278
+
279
+ # if no return statement we still need to return a single field of type agtype
280
+ else:
281
+ fields_str = "a agtype"
282
+
283
+ select_str = "*"
284
+
285
+ query = query.format(**params)
286
+
287
+ return template.format(
288
+ graph_name=graph_name,
289
+ query=query,
290
+ fields=fields_str,
291
+ projection=select_str,
292
+ )
293
+
294
async def _query(self, query: str, **params: str) -> List[Dict[str, Any]]:
    """
    Query the graph by taking a cypher query, converting it to an
    age compatible query, executing it and converting the result

    Args:
        query (str): a cypher query to be executed
        params (dict): parameters for the query

    Returns:
        List[Dict[str, Any]]: a list of dictionaries containing the result set

    Raises:
        AGEQueryException: if the wrapped query fails; the database error
            text is carried in the exception details
    """
    # convert cypher query to pgsql/age query
    wrapped_query = self._wrap_query(query, self.graph_name, **params)

    await self._driver.open()

    # create graph if it doesn't exist
    async with self._get_pool_connection() as conn:
        async with conn.cursor() as curs:
            try:
                await curs.execute('SET search_path = ag_catalog, "$user", public')
                await curs.execute(f"SELECT create_graph('{self.graph_name}')")
                await conn.commit()
            except (
                psycopg.errors.InvalidSchemaName,
                psycopg.errors.UniqueViolation,
            ):
                # presumably raised when the graph already exists; either
                # way the failed statement is rolled back and ignored
                await conn.rollback()

    # execute the query, rolling back on an error
    async with self._get_pool_connection() as conn:
        async with conn.cursor(row_factory=namedtuple_row) as curs:
            try:
                await curs.execute('SET search_path = ag_catalog, "$user", public')
                await curs.execute(wrapped_query)
                await conn.commit()
            except psycopg.Error as e:
                await conn.rollback()
                raise AGEQueryException(
                    {
                        "message": f"Error executing graph query: {query.format(**params)}",
                        # "details" is the key AGEQueryException reads; the
                        # former "detail" spelling dropped the error text
                        "details": str(e),
                    }
                ) from e

            data = await curs.fetchall()
            if data is None:
                result = []
            # decode records
            else:
                result = [AGEStorage._record_to_dict(d) for d in data]

            return result
348
+
349
async def has_node(self, node_id: str) -> bool:
    """
    Check whether a node exists in the graph.

    Args:
        node_id (str): the node identifier; surrounding double quotes are
            stripped and the rest is used (encoded) as the node label

    Returns:
        bool: True if at least one node carries that label
    """
    entity_name_label = node_id.strip('"')

    query = "MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists"
    rows = await self._query(
        query, label=AGEStorage._encode_graph_label(entity_name_label)
    )
    # Rows are dictionaries keyed by column name, so the value has to be
    # looked up by name rather than by position.
    node_exists = rows[0]["node_exists"] if rows else False
    # The value is normally already a bool, because plain result values are
    # JSON-decoded; the textual form is tolerated as well.
    if isinstance(node_exists, str):
        node_exists = node_exists.lower() == "true"
    node_exists = bool(node_exists)

    logger.debug(
        "{%s}:query:{%s}:result:{%s}",
        inspect.currentframe().f_code.co_name,
        query,
        node_exists,
    )

    return node_exists
366
+
367
async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
    """
    Check whether any edge, in either direction, connects two nodes.

    Args:
        source_node_id (str): identifier of the first node; surrounding
            double quotes are stripped
        target_node_id (str): identifier of the second node; surrounding
            double quotes are stripped

    Returns:
        bool: True if at least one edge links the two nodes
    """
    entity_name_label_source = source_node_id.strip('"')
    entity_name_label_target = target_node_id.strip('"')

    # The alias is lower case on purpose: column names are derived from the
    # lower-cased query text, so a camelCase alias could never be found in
    # the result row.
    query = (
        "MATCH (a:`{src_label}`)-[r]-(b:`{tgt_label}`) "
        "RETURN COUNT(r) > 0 AS edge_exists"
    )
    rows = await self._query(
        query,
        src_label=AGEStorage._encode_graph_label(entity_name_label_source),
        tgt_label=AGEStorage._encode_graph_label(entity_name_label_target),
    )
    # Rows are dictionaries keyed by column name, not sequences.
    edge_exists = rows[0]["edge_exists"] if rows else False
    # The value is normally already a bool, because plain result values are
    # JSON-decoded; the textual form is tolerated as well.
    if isinstance(edge_exists, str):
        edge_exists = edge_exists.lower() == "true"
    edge_exists = bool(edge_exists)

    logger.debug(
        "{%s}:query:{%s}:result:{%s}",
        inspect.currentframe().f_code.co_name,
        query,
        edge_exists,
    )
    return edge_exists
388
+
389
+ async def get_node(self, node_id: str) -> Union[dict, None]:
390
+ entity_name_label = node_id.strip('"')
391
+ query = "MATCH (n:`{label}`) RETURN n"
392
+ record = await self._query(
393
+ query, label=AGEStorage._encode_graph_label(entity_name_label)
394
+ )
395
+ if record:
396
+ node = record[0]
397
+ node_dict = node["n"]
398
+ logger.debug(
399
+ "{%s}: query: {%s}, result: {%s}",
400
+ inspect.currentframe().f_code.co_name,
401
+ query,
402
+ node_dict,
403
+ )
404
+ return node_dict
405
+ return None
406
+
407
+ async def node_degree(self, node_id: str) -> int:
408
+ entity_name_label = node_id.strip('"')
409
+
410
+ query = """
411
+ MATCH (n:`{label}`)-[]->(x)
412
+ RETURN count(x) AS total_edge_count
413
+ """
414
+ record = (
415
+ await self._query(
416
+ query, label=AGEStorage._encode_graph_label(entity_name_label)
417
+ )
418
+ )[0]
419
+ if record:
420
+ edge_count = int(record["total_edge_count"])
421
+ logger.debug(
422
+ "{%s}:query:{%s}:result:{%s}",
423
+ inspect.currentframe().f_code.co_name,
424
+ query,
425
+ edge_count,
426
+ )
427
+ return edge_count
428
+
429
async def edge_degree(self, src_id: str, tgt_id: str) -> int:
    """
    Sum the degrees of the two endpoints of an edge.

    Args:
        src_id (str): identifier of the source node; surrounding double
            quotes are stripped
        tgt_id (str): identifier of the target node; surrounding double
            quotes are stripped

    Returns:
        int: source degree plus target degree, a missing degree counting as 0
    """
    src_degree = await self.node_degree(src_id.strip('"'))
    trg_degree = await self.node_degree(tgt_id.strip('"'))

    # Convert None to 0 for addition
    degrees = sum(
        int(0 if degree is None else degree) for degree in (src_degree, trg_degree)
    )
    logger.debug(
        "{%s}:query:src_Degree+trg_degree:result:{%s}",
        inspect.currentframe().f_code.co_name,
        degrees,
    )
    return degrees
446
+
447
+ async def get_edge(
448
+ self, source_node_id: str, target_node_id: str
449
+ ) -> Union[dict, None]:
450
+ """
451
+ Find all edges between nodes of two given labels
452
+
453
+ Args:
454
+ source_node_label (str): Label of the source nodes
455
+ target_node_label (str): Label of the target nodes
456
+
457
+ Returns:
458
+ list: List of all relationships/edges found
459
+ """
460
+ entity_name_label_source = source_node_id.strip('"')
461
+ entity_name_label_target = target_node_id.strip('"')
462
+
463
+ query = """
464
+ MATCH (a:`{src_label}`)-[r]->(b:`{tgt_label}`)
465
+ RETURN properties(r) as edge_properties
466
+ LIMIT 1
467
+ """
468
+
469
+ record = await self._query(
470
+ query,
471
+ src_label=AGEStorage._encode_graph_label(entity_name_label_source),
472
+ tgt_label=AGEStorage._encode_graph_label(entity_name_label_target),
473
+ )
474
+ if record and record[0] and record[0]["edge_properties"]:
475
+ result = record[0]["edge_properties"]
476
+ logger.debug(
477
+ "{%s}:query:{%s}:result:{%s}",
478
+ inspect.currentframe().f_code.co_name,
479
+ query,
480
+ result,
481
+ )
482
+ return result
483
+
484
+ async def get_node_edges(self, source_node_id: str) -> List[Tuple[str, str]]:
485
+ """
486
+ Retrieves all edges (relationships) for a particular node identified by its label.
487
+ :return: List of dictionaries containing edge information
488
+ """
489
+ node_label = source_node_id.strip('"')
490
+
491
+ query = """MATCH (n:`{label}`)
492
+ OPTIONAL MATCH (n)-[r]-(connected)
493
+ RETURN n, r, connected"""
494
+ results = await self._query(
495
+ query, label=AGEStorage._encode_graph_label(node_label)
496
+ )
497
+ edges = []
498
+ for record in results:
499
+ source_node = record["n"] if record["n"] else None
500
+ connected_node = record["connected"] if record["connected"] else None
501
+
502
+ source_label = (
503
+ source_node["label"] if source_node and source_node["label"] else None
504
+ )
505
+ target_label = (
506
+ connected_node["label"]
507
+ if connected_node and connected_node["label"]
508
+ else None
509
+ )
510
+
511
+ if source_label and target_label:
512
+ edges.append((source_label, target_label))
513
+
514
+ return edges
515
+
516
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((AGEQueryException,)),
)
async def upsert_node(self, node_id: str, node_data: Dict[str, Any]):
    """
    Create the node if it is missing and merge the given properties into it.

    The call is retried, up to three attempts, when the query fails with
    an AGEQueryException.

    Args:
        node_id: The unique identifier for the node (used as label);
            surrounding double quotes are stripped
        node_data: Dictionary of node properties
    """
    label = node_id.strip('"')

    query = """
            MERGE (n:`{label}`)
            SET n += {properties}
            """
    try:
        query_params = {
            "label": AGEStorage._encode_graph_label(label),
            "properties": AGEStorage._format_properties(node_data),
        }
        await self._query(query, **query_params)
        logger.debug(
            "Upserted node with label '{%s}' and properties: {%s}",
            label,
            node_data,
        )
    except Exception as e:
        # log, then let the error propagate (and trigger the retry policy)
        logger.error("Error during upsert: {%s}", e)
        raise
550
+
551
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=4, max=10),
    retry=retry_if_exception_type((AGEQueryException,)),
)
async def upsert_edge(
    self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
):
    """
    Create the DIRECTED edge between two existing nodes if it is missing
    and merge the given properties into it.

    The call is retried, up to three attempts, when the query fails with
    an AGEQueryException.

    Args:
        source_node_id (str): Label of the source node (used as identifier)
        target_node_id (str): Label of the target node (used as identifier)
        edge_data (dict): Dictionary of properties to set on the edge
    """
    source_node_label = source_node_id.strip('"')
    target_node_label = target_node_id.strip('"')

    query = """
            MATCH (source:`{src_label}`)
            WITH source
            MATCH (target:`{tgt_label}`)
            MERGE (source)-[r:DIRECTED]->(target)
            SET r += {properties}
            RETURN r
            """
    try:
        query_params = {
            "src_label": AGEStorage._encode_graph_label(source_node_label),
            "tgt_label": AGEStorage._encode_graph_label(target_node_label),
            "properties": AGEStorage._format_properties(edge_data),
        }
        await self._query(query, **query_params)
        logger.debug(
            "Upserted edge from '{%s}' to '{%s}' with properties: {%s}",
            source_node_label,
            target_node_label,
            edge_data,
        )
    except Exception as e:
        # log, then let the error propagate (and trigger the retry policy)
        logger.error("Error during edge upsert: {%s}", e)
        raise
595
+
596
+ async def _node2vec_embed(self):
597
+ print("Implemented but never called.")
598
+
599
@asynccontextmanager
async def _get_pool_connection(self, timeout: Optional[float] = None):
    """Yield a pooled connection and hand it back to the pool afterwards.

    Workaround for a psycopg_pool bug: when fetching a connection times
    out, one connection is added to the pool and the fetch is tried once more.

    Args:
        timeout (Optional[float]): passed to the pool when fetching a
            connection
    """
    try:
        conn = await self._driver.getconn(timeout=timeout)
    except PoolTimeout:
        await self._driver._add_connection(None)  # workaround...
        conn = await self._driver.getconn(timeout=timeout)

    try:
        async with conn:
            yield conn
    finally:
        # always return the connection, even if the caller's block raised
        await self._driver.putconn(conn)
lightrag/lightrag.py CHANGED
@@ -79,6 +79,7 @@ MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage")
79
  ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage")
80
  TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage")
81
  TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage")
 
82
 
83
 
84
  def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
@@ -273,6 +274,7 @@ class LightRAG:
273
  "NetworkXStorage": NetworkXStorage,
274
  "Neo4JStorage": Neo4JStorage,
275
  "OracleGraphStorage": OracleGraphStorage,
 
276
  # "ArangoDBStorage": ArangoDBStorage
277
  }
278
 
 
79
  ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage")
80
  TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage")
81
  TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage")
82
+ AGEStorage = lazy_external_import(".kg.age_impl", "AGEStorage")
83
 
84
 
85
  def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
 
274
  "NetworkXStorage": NetworkXStorage,
275
  "Neo4JStorage": Neo4JStorage,
276
  "OracleGraphStorage": OracleGraphStorage,
277
+ "AGEStorage": AGEStorage,
278
  # "ArangoDBStorage": ArangoDBStorage
279
  }
280
 
requirements.txt CHANGED
@@ -11,6 +11,7 @@ networkx
11
  ollama
12
  openai
13
  oracledb
 
14
  pymilvus
15
  pymongo
16
  pymysql
 
11
  ollama
12
  openai
13
  oracledb
14
+ psycopg[binary,pool]
15
  pymilvus
16
  pymongo
17
  pymysql