samuel-z-chen commited on
Commit
d19a515
·
1 Parent(s): 4d8e9a6

Fix the lint issue

Browse files
examples/lightrag_zhipu_postgres_demo.py CHANGED
@@ -53,7 +53,7 @@ async def main():
53
  kv_storage="PGKVStorage",
54
  doc_status_storage="PGDocStatusStorage",
55
  graph_storage="PGGraphStorage",
56
- vector_storage="PGVectorStorage"
57
  )
58
  # Set the KV/vector/graph storage's `db` property, so all operation will use same connection pool
59
  rag.doc_status.db = postgres_db
@@ -77,27 +77,35 @@ async def main():
77
  start_time = time.time()
78
  # Perform naive search
79
  print(
80
- await rag.aquery("What are the top themes in this story?", param=QueryParam(mode="naive"))
 
 
81
  )
82
  print(f"Naive Query Time: {time.time() - start_time} seconds")
83
  # Perform local search
84
  print("**** Start Local Query ****")
85
  start_time = time.time()
86
  print(
87
- await rag.aquery("What are the top themes in this story?", param=QueryParam(mode="local"))
 
 
88
  )
89
  print(f"Local Query Time: {time.time() - start_time} seconds")
90
  # Perform global search
91
  print("**** Start Global Query ****")
92
  start_time = time.time()
93
  print(
94
- await rag.aquery("What are the top themes in this story?", param=QueryParam(mode="global"))
 
 
95
  )
96
  print(f"Global Query Time: {time.time() - start_time}")
97
  # Perform hybrid search
98
  print("**** Start Hybrid Query ****")
99
  print(
100
- await rag.aquery("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
 
 
101
  )
102
  print(f"Hybrid Query Time: {time.time() - start_time} seconds")
103
 
 
53
  kv_storage="PGKVStorage",
54
  doc_status_storage="PGDocStatusStorage",
55
  graph_storage="PGGraphStorage",
56
+ vector_storage="PGVectorStorage",
57
  )
58
  # Set the KV/vector/graph storage's `db` property, so all operation will use same connection pool
59
  rag.doc_status.db = postgres_db
 
77
  start_time = time.time()
78
  # Perform naive search
79
  print(
80
+ await rag.aquery(
81
+ "What are the top themes in this story?", param=QueryParam(mode="naive")
82
+ )
83
  )
84
  print(f"Naive Query Time: {time.time() - start_time} seconds")
85
  # Perform local search
86
  print("**** Start Local Query ****")
87
  start_time = time.time()
88
  print(
89
+ await rag.aquery(
90
+ "What are the top themes in this story?", param=QueryParam(mode="local")
91
+ )
92
  )
93
  print(f"Local Query Time: {time.time() - start_time} seconds")
94
  # Perform global search
95
  print("**** Start Global Query ****")
96
  start_time = time.time()
97
  print(
98
+ await rag.aquery(
99
+ "What are the top themes in this story?", param=QueryParam(mode="global")
100
+ )
101
  )
102
  print(f"Global Query Time: {time.time() - start_time}")
103
  # Perform hybrid search
104
  print("**** Start Hybrid Query ****")
105
  print(
106
+ await rag.aquery(
107
+ "What are the top themes in this story?", param=QueryParam(mode="hybrid")
108
+ )
109
  )
110
  print(f"Hybrid Query Time: {time.time() - start_time} seconds")
111
 
lightrag/kg/postgres_impl.py CHANGED
@@ -19,7 +19,11 @@ from tenacity import (
19
  from ..utils import logger
20
  from ..base import (
21
  BaseKVStorage,
22
- BaseVectorStorage, DocStatusStorage, DocStatus, DocProcessingStatus, BaseGraphStorage,
 
 
 
 
23
  )
24
 
25
  if sys.platform.startswith("win"):
@@ -36,14 +40,15 @@ class PostgreSQLDB:
36
  self.user = config.get("user", "postgres")
37
  self.password = config.get("password", None)
38
  self.database = config.get("database", "postgres")
39
- self.workspace = config.get("workspace", 'default')
40
  self.max = 12
41
  self.increment = 1
42
  logger.info(f"Using the label {self.workspace} for PostgreSQL as identifier")
43
 
44
  if self.user is None or self.password is None or self.database is None:
45
- raise ValueError("Missing database user, password, or database in addon_params")
46
-
 
47
 
48
  async def initdb(self):
49
  try:
@@ -54,12 +59,16 @@ class PostgreSQLDB:
54
  host=self.host,
55
  port=self.port,
56
  min_size=1,
57
- max_size=self.max
58
  )
59
 
60
- logger.info(f"Connected to PostgreSQL database at {self.host}:{self.port}/{self.database}")
 
 
61
  except Exception as e:
62
- logger.error(f"Failed to connect to PostgreSQL database at {self.host}:{self.port}/{self.database}")
 
 
63
  logger.error(f"PostgreSQL database error: {e}")
64
  raise
65
 
@@ -79,9 +88,13 @@ class PostgreSQLDB:
79
 
80
  logger.info("Finished checking all tables in PostgreSQL database")
81
 
82
-
83
  async def query(
84
- self, sql: str, params: dict = None, multirows: bool = False, for_age: bool = False, graph_name: str = None
 
 
 
 
 
85
  ) -> Union[dict, None, list[dict]]:
86
  async with self.pool.acquire() as connection:
87
  try:
@@ -111,7 +124,13 @@ class PostgreSQLDB:
111
  print(params)
112
  raise
113
 
114
- async def execute(self, sql: str, data: Union[list, dict] = None, for_age: bool = False, graph_name: str = None):
 
 
 
 
 
 
115
  try:
116
  async with self.pool.acquire() as connection:
117
  if for_age:
@@ -130,7 +149,7 @@ class PostgreSQLDB:
130
  @staticmethod
131
  async def _prerequisite(conn: asyncpg.Connection, graph_name: str):
132
  try:
133
- await conn.execute(f'SET search_path = ag_catalog, "$user", public')
134
  await conn.execute(f"""select create_graph('{graph_name}')""")
135
  except asyncpg.exceptions.InvalidSchemaNameError:
136
  pass
@@ -138,7 +157,7 @@ class PostgreSQLDB:
138
 
139
  @dataclass
140
  class PGKVStorage(BaseKVStorage):
141
- db:PostgreSQLDB = None
142
 
143
  def __post_init__(self):
144
  self._data = {}
@@ -180,7 +199,7 @@ class PGKVStorage(BaseKVStorage):
180
  dict_res[mode] = {}
181
  for row in array_res:
182
  dict_res[row["mode"]][row["id"]] = row
183
- res = [{k:v} for k,v in dict_res.items()]
184
  else:
185
  res = await self.db.query(sql, params, multirows=True)
186
  if res:
@@ -191,7 +210,8 @@ class PGKVStorage(BaseKVStorage):
191
  async def filter_keys(self, keys: List[str]) -> Set[str]:
192
  """Filter out duplicated content"""
193
  sql = SQL_TEMPLATES["filter_keys"].format(
194
- table_name=NAMESPACE_TABLE_MAP[self.namespace], ids=",".join([f"'{id}'" for id in keys])
 
195
  )
196
  params = {"workspace": self.db.workspace}
197
  try:
@@ -207,7 +227,6 @@ class PGKVStorage(BaseKVStorage):
207
  print(sql)
208
  print(params)
209
 
210
-
211
  ################ INSERT METHODS ################
212
  async def upsert(self, data: Dict[str, dict]):
213
  left_data = {k: v for k, v in data.items() if k not in self._data}
@@ -246,7 +265,7 @@ class PGKVStorage(BaseKVStorage):
246
  @dataclass
247
  class PGVectorStorage(BaseVectorStorage):
248
  cosine_better_than_threshold: float = 0.2
249
- db:PostgreSQLDB = None
250
 
251
  def __post_init__(self):
252
  self._max_batch_size = self.global_config["embedding_batch_num"]
@@ -282,6 +301,7 @@ class PGVectorStorage(BaseVectorStorage):
282
  "content_vector": json.dumps(item["__vector__"].tolist()),
283
  }
284
  return upsert_sql, data
 
285
  def _upsert_relationships(self, item: dict):
286
  upsert_sql = SQL_TEMPLATES["upsert_relationship"]
287
  data = {
@@ -340,8 +360,6 @@ class PGVectorStorage(BaseVectorStorage):
340
 
341
  await self.db.execute(upsert_sql, data)
342
 
343
-
344
-
345
  async def index_done_callback(self):
346
  logger.info("vector data had been saved into postgresql db!")
347
 
@@ -350,7 +368,7 @@ class PGVectorStorage(BaseVectorStorage):
350
  """从向量数据库中查询数据"""
351
  embeddings = await self.embedding_func([query])
352
  embedding = embeddings[0]
353
- embedding_string = ",".join(map(str, embedding))
354
 
355
  sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
356
  params = {
@@ -361,10 +379,12 @@ class PGVectorStorage(BaseVectorStorage):
361
  results = await self.db.query(sql, params=params, multirows=True)
362
  return results
363
 
 
364
  @dataclass
365
  class PGDocStatusStorage(DocStatusStorage):
366
  """PostgreSQL implementation of document status storage"""
367
- db:PostgreSQLDB = None
 
368
 
369
  def __post_init__(self):
370
  pass
@@ -372,41 +392,47 @@ class PGDocStatusStorage(DocStatusStorage):
372
  async def filter_keys(self, data: list[str]) -> set[str]:
373
  """Return keys that don't exist in storage"""
374
  sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id IN ({",".join([f"'{_id}'" for _id in data])})"
375
- result = await self.db.query(sql, {'workspace': self.db.workspace}, True)
376
  # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
377
  if result is None:
378
  return set(data)
379
  else:
380
- existed = set([element['id'] for element in result])
381
  return set(data) - existed
382
 
383
  async def get_status_counts(self) -> Dict[str, int]:
384
  """Get counts of documents in each status"""
385
- sql = '''SELECT status as "status", COUNT(1) as "count"
386
  FROM LIGHTRAG_DOC_STATUS
387
  where workspace=$1 GROUP BY STATUS
388
- '''
389
- result = await self.db.query(sql, {'workspace': self.db.workspace}, True)
390
  # Result is like [{'status': 'PENDING', 'count': 1}, {'status': 'PROCESSING', 'count': 2}, ...]
391
  counts = {}
392
  for doc in result:
393
  counts[doc["status"]] = doc["count"]
394
  return counts
395
 
396
- async def get_docs_by_status(self, status: DocStatus) -> Dict[str, DocProcessingStatus]:
 
 
397
  """Get all documents by status"""
398
- sql = 'select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$1'
399
- params = {'workspace': self.db.workspace, 'status': status}
400
  result = await self.db.query(sql, params, True)
401
  # Result is like [{'id': 'id1', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, {'id': 'id2', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, ...]
402
  # Converting to be a dict
403
- return {element["id"]:
404
- DocProcessingStatus(content_summary=element["content_summary"],
405
- content_length=element["content_length"],
406
- status=element["status"],
407
- created_at=element["created_at"],
408
- updated_at=element["updated_at"],
409
- chunks_count=element["chunks_count"]) for element in result}
 
 
 
 
410
 
411
  async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
412
  """Get all failed documents"""
@@ -436,14 +462,17 @@ class PGDocStatusStorage(DocStatusStorage):
436
  updated_at = CURRENT_TIMESTAMP"""
437
  for k, v in data.items():
438
  # chunks_count is optional
439
- await self.db.execute(sql, {
440
- "workspace": self.db.workspace,
441
- "id": k,
442
- "content_summary": v["content_summary"],
443
- "content_length": v["content_length"],
444
- "chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
445
- "status": v["status"],
446
- })
 
 
 
447
  return data
448
 
449
 
@@ -467,7 +496,7 @@ class PGGraphQueryException(Exception):
467
 
468
  @dataclass
469
  class PGGraphStorage(BaseGraphStorage):
470
- db:PostgreSQLDB = None
471
 
472
  @staticmethod
473
  def load_nx_graph(file_name):
@@ -484,7 +513,6 @@ class PGGraphStorage(BaseGraphStorage):
484
  "node2vec": self._node2vec_embed,
485
  }
486
 
487
-
488
  async def index_done_callback(self):
489
  print("KG successfully indexed.")
490
 
@@ -552,7 +580,7 @@ class PGGraphStorage(BaseGraphStorage):
552
 
553
  @staticmethod
554
  def _format_properties(
555
- properties: Dict[str, Any], _id: Union[str, None] = None
556
  ) -> str:
557
  """
558
  Convert a dictionary of properties to a string representation that
@@ -669,7 +697,8 @@ class PGGraphStorage(BaseGraphStorage):
669
 
670
  # get pgsql formatted field names
671
  fields = [
672
- PGGraphStorage._get_col_name(field, idx) for idx, field in enumerate(fields)
 
673
  ]
674
 
675
  # build resulting pgsql relation
@@ -690,7 +719,9 @@ class PGGraphStorage(BaseGraphStorage):
690
  projection=select_str,
691
  )
692
 
693
- async def _query(self, query: str, readonly=True, upsert_edge=False, **params: str) -> List[Dict[str, Any]]:
 
 
694
  """
695
  Query the graph by taking a cypher query, converting it to an
696
  age compatible query, executing it and converting the result
@@ -708,14 +739,25 @@ class PGGraphStorage(BaseGraphStorage):
708
  # execute the query, rolling back on an error
709
  try:
710
  if readonly:
711
- data = await self.db.query(wrapped_query, multirows=True, for_age=True, graph_name=self.graph_name)
 
 
 
 
 
712
  else:
713
  # for upserting edge, need to run the SQL twice, otherwise cannot update the properties. (First time it will try to create the edge, second time is MERGING)
714
  # It is a bug of AGE as of 2025-01-03, hope it can be resolved in the future.
715
  if upsert_edge:
716
- data = await self.db.execute(f"{wrapped_query};{wrapped_query};", for_age=True, graph_name=self.graph_name)
 
 
 
 
717
  else:
718
- data = await self.db.execute(wrapped_query, for_age=True, graph_name=self.graph_name)
 
 
719
  except Exception as e:
720
  raise PGGraphQueryException(
721
  {
@@ -819,7 +861,7 @@ class PGGraphStorage(BaseGraphStorage):
819
  return degrees
820
 
821
  async def get_edge(
822
- self, source_node_id: str, target_node_id: str
823
  ) -> Union[dict, None]:
824
  """
825
  Find all edges between nodes of two given labels
@@ -922,7 +964,7 @@ class PGGraphStorage(BaseGraphStorage):
922
  retry=retry_if_exception_type((PGGraphQueryException,)),
923
  )
924
  async def upsert_edge(
925
- self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
926
  ):
927
  """
928
  Upsert an edge and its properties between two nodes identified by their labels.
@@ -935,7 +977,9 @@ class PGGraphStorage(BaseGraphStorage):
935
  source_node_label = source_node_id.strip('"')
936
  target_node_label = target_node_id.strip('"')
937
  edge_properties = edge_data
938
- logger.info(f"-- inserting edge: {source_node_label} -> {target_node_label}: {edge_data}")
 
 
939
 
940
  query = """MATCH (source:`{src_label}`)
941
  WITH source
@@ -1056,7 +1100,6 @@ TABLES = {
1056
  }
1057
 
1058
 
1059
-
1060
  SQL_TEMPLATES = {
1061
  # SQL for KVStorage
1062
  "get_by_id_full_docs": """SELECT id, COALESCE(content, '') as content
@@ -1107,7 +1150,7 @@ SQL_TEMPLATES = {
1107
  "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector)
1108
  VALUES ($1, $2, $3, $4, $5)
1109
  ON CONFLICT (workspace,id) DO UPDATE
1110
- SET entity_name=EXCLUDED.entity_name,
1111
  content=EXCLUDED.content,
1112
  content_vector=EXCLUDED.content_vector,
1113
  updatetime=CURRENT_TIMESTAMP
@@ -1136,5 +1179,5 @@ SQL_TEMPLATES = {
1136
  (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
1137
  FROM LIGHTRAG_DOC_CHUNKS where workspace=$1)
1138
  WHERE distance>$2 ORDER BY distance DESC LIMIT $3
1139
- """
1140
  }
 
19
  from ..utils import logger
20
  from ..base import (
21
  BaseKVStorage,
22
+ BaseVectorStorage,
23
+ DocStatusStorage,
24
+ DocStatus,
25
+ DocProcessingStatus,
26
+ BaseGraphStorage,
27
  )
28
 
29
  if sys.platform.startswith("win"):
 
40
  self.user = config.get("user", "postgres")
41
  self.password = config.get("password", None)
42
  self.database = config.get("database", "postgres")
43
+ self.workspace = config.get("workspace", "default")
44
  self.max = 12
45
  self.increment = 1
46
  logger.info(f"Using the label {self.workspace} for PostgreSQL as identifier")
47
 
48
  if self.user is None or self.password is None or self.database is None:
49
+ raise ValueError(
50
+ "Missing database user, password, or database in addon_params"
51
+ )
52
 
53
  async def initdb(self):
54
  try:
 
59
  host=self.host,
60
  port=self.port,
61
  min_size=1,
62
+ max_size=self.max,
63
  )
64
 
65
+ logger.info(
66
+ f"Connected to PostgreSQL database at {self.host}:{self.port}/{self.database}"
67
+ )
68
  except Exception as e:
69
+ logger.error(
70
+ f"Failed to connect to PostgreSQL database at {self.host}:{self.port}/{self.database}"
71
+ )
72
  logger.error(f"PostgreSQL database error: {e}")
73
  raise
74
 
 
88
 
89
  logger.info("Finished checking all tables in PostgreSQL database")
90
 
 
91
  async def query(
92
+ self,
93
+ sql: str,
94
+ params: dict = None,
95
+ multirows: bool = False,
96
+ for_age: bool = False,
97
+ graph_name: str = None,
98
  ) -> Union[dict, None, list[dict]]:
99
  async with self.pool.acquire() as connection:
100
  try:
 
124
  print(params)
125
  raise
126
 
127
+ async def execute(
128
+ self,
129
+ sql: str,
130
+ data: Union[list, dict] = None,
131
+ for_age: bool = False,
132
+ graph_name: str = None,
133
+ ):
134
  try:
135
  async with self.pool.acquire() as connection:
136
  if for_age:
 
149
  @staticmethod
150
  async def _prerequisite(conn: asyncpg.Connection, graph_name: str):
151
  try:
152
+ await conn.execute('SET search_path = ag_catalog, "$user", public')
153
  await conn.execute(f"""select create_graph('{graph_name}')""")
154
  except asyncpg.exceptions.InvalidSchemaNameError:
155
  pass
 
157
 
158
  @dataclass
159
  class PGKVStorage(BaseKVStorage):
160
+ db: PostgreSQLDB = None
161
 
162
  def __post_init__(self):
163
  self._data = {}
 
199
  dict_res[mode] = {}
200
  for row in array_res:
201
  dict_res[row["mode"]][row["id"]] = row
202
+ res = [{k: v} for k, v in dict_res.items()]
203
  else:
204
  res = await self.db.query(sql, params, multirows=True)
205
  if res:
 
210
  async def filter_keys(self, keys: List[str]) -> Set[str]:
211
  """Filter out duplicated content"""
212
  sql = SQL_TEMPLATES["filter_keys"].format(
213
+ table_name=NAMESPACE_TABLE_MAP[self.namespace],
214
+ ids=",".join([f"'{id}'" for id in keys]),
215
  )
216
  params = {"workspace": self.db.workspace}
217
  try:
 
227
  print(sql)
228
  print(params)
229
 
 
230
  ################ INSERT METHODS ################
231
  async def upsert(self, data: Dict[str, dict]):
232
  left_data = {k: v for k, v in data.items() if k not in self._data}
 
265
  @dataclass
266
  class PGVectorStorage(BaseVectorStorage):
267
  cosine_better_than_threshold: float = 0.2
268
+ db: PostgreSQLDB = None
269
 
270
  def __post_init__(self):
271
  self._max_batch_size = self.global_config["embedding_batch_num"]
 
301
  "content_vector": json.dumps(item["__vector__"].tolist()),
302
  }
303
  return upsert_sql, data
304
+
305
  def _upsert_relationships(self, item: dict):
306
  upsert_sql = SQL_TEMPLATES["upsert_relationship"]
307
  data = {
 
360
 
361
  await self.db.execute(upsert_sql, data)
362
 
 
 
363
  async def index_done_callback(self):
364
  logger.info("vector data had been saved into postgresql db!")
365
 
 
368
  """从向量数据库中查询数据"""
369
  embeddings = await self.embedding_func([query])
370
  embedding = embeddings[0]
371
+ embedding_string = ",".join(map(str, embedding))
372
 
373
  sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
374
  params = {
 
379
  results = await self.db.query(sql, params=params, multirows=True)
380
  return results
381
 
382
+
383
  @dataclass
384
  class PGDocStatusStorage(DocStatusStorage):
385
  """PostgreSQL implementation of document status storage"""
386
+
387
+ db: PostgreSQLDB = None
388
 
389
  def __post_init__(self):
390
  pass
 
392
  async def filter_keys(self, data: list[str]) -> set[str]:
393
  """Return keys that don't exist in storage"""
394
  sql = f"SELECT id FROM LIGHTRAG_DOC_STATUS WHERE workspace=$1 AND id IN ({",".join([f"'{_id}'" for _id in data])})"
395
+ result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
396
  # The result is like [{'id': 'id1'}, {'id': 'id2'}, ...].
397
  if result is None:
398
  return set(data)
399
  else:
400
+ existed = set([element["id"] for element in result])
401
  return set(data) - existed
402
 
403
  async def get_status_counts(self) -> Dict[str, int]:
404
  """Get counts of documents in each status"""
405
+ sql = """SELECT status as "status", COUNT(1) as "count"
406
  FROM LIGHTRAG_DOC_STATUS
407
  where workspace=$1 GROUP BY STATUS
408
+ """
409
+ result = await self.db.query(sql, {"workspace": self.db.workspace}, True)
410
  # Result is like [{'status': 'PENDING', 'count': 1}, {'status': 'PROCESSING', 'count': 2}, ...]
411
  counts = {}
412
  for doc in result:
413
  counts[doc["status"]] = doc["count"]
414
  return counts
415
 
416
+ async def get_docs_by_status(
417
+ self, status: DocStatus
418
+ ) -> Dict[str, DocProcessingStatus]:
419
  """Get all documents by status"""
420
+ sql = "select * from LIGHTRAG_DOC_STATUS where workspace=$1 and status=$1"
421
+ params = {"workspace": self.db.workspace, "status": status}
422
  result = await self.db.query(sql, params, True)
423
  # Result is like [{'id': 'id1', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, {'id': 'id2', 'status': 'PENDING', 'updated_at': '2023-07-01 00:00:00'}, ...]
424
  # Converting to be a dict
425
+ return {
426
+ element["id"]: DocProcessingStatus(
427
+ content_summary=element["content_summary"],
428
+ content_length=element["content_length"],
429
+ status=element["status"],
430
+ created_at=element["created_at"],
431
+ updated_at=element["updated_at"],
432
+ chunks_count=element["chunks_count"],
433
+ )
434
+ for element in result
435
+ }
436
 
437
  async def get_failed_docs(self) -> Dict[str, DocProcessingStatus]:
438
  """Get all failed documents"""
 
462
  updated_at = CURRENT_TIMESTAMP"""
463
  for k, v in data.items():
464
  # chunks_count is optional
465
+ await self.db.execute(
466
+ sql,
467
+ {
468
+ "workspace": self.db.workspace,
469
+ "id": k,
470
+ "content_summary": v["content_summary"],
471
+ "content_length": v["content_length"],
472
+ "chunks_count": v["chunks_count"] if "chunks_count" in v else -1,
473
+ "status": v["status"],
474
+ },
475
+ )
476
  return data
477
 
478
 
 
496
 
497
  @dataclass
498
  class PGGraphStorage(BaseGraphStorage):
499
+ db: PostgreSQLDB = None
500
 
501
  @staticmethod
502
  def load_nx_graph(file_name):
 
513
  "node2vec": self._node2vec_embed,
514
  }
515
 
 
516
  async def index_done_callback(self):
517
  print("KG successfully indexed.")
518
 
 
580
 
581
  @staticmethod
582
  def _format_properties(
583
+ properties: Dict[str, Any], _id: Union[str, None] = None
584
  ) -> str:
585
  """
586
  Convert a dictionary of properties to a string representation that
 
697
 
698
  # get pgsql formatted field names
699
  fields = [
700
+ PGGraphStorage._get_col_name(field, idx)
701
+ for idx, field in enumerate(fields)
702
  ]
703
 
704
  # build resulting pgsql relation
 
719
  projection=select_str,
720
  )
721
 
722
+ async def _query(
723
+ self, query: str, readonly=True, upsert_edge=False, **params: str
724
+ ) -> List[Dict[str, Any]]:
725
  """
726
  Query the graph by taking a cypher query, converting it to an
727
  age compatible query, executing it and converting the result
 
739
  # execute the query, rolling back on an error
740
  try:
741
  if readonly:
742
+ data = await self.db.query(
743
+ wrapped_query,
744
+ multirows=True,
745
+ for_age=True,
746
+ graph_name=self.graph_name,
747
+ )
748
  else:
749
  # for upserting edge, need to run the SQL twice, otherwise cannot update the properties. (First time it will try to create the edge, second time is MERGING)
750
  # It is a bug of AGE as of 2025-01-03, hope it can be resolved in the future.
751
  if upsert_edge:
752
+ data = await self.db.execute(
753
+ f"{wrapped_query};{wrapped_query};",
754
+ for_age=True,
755
+ graph_name=self.graph_name,
756
+ )
757
  else:
758
+ data = await self.db.execute(
759
+ wrapped_query, for_age=True, graph_name=self.graph_name
760
+ )
761
  except Exception as e:
762
  raise PGGraphQueryException(
763
  {
 
861
  return degrees
862
 
863
  async def get_edge(
864
+ self, source_node_id: str, target_node_id: str
865
  ) -> Union[dict, None]:
866
  """
867
  Find all edges between nodes of two given labels
 
964
  retry=retry_if_exception_type((PGGraphQueryException,)),
965
  )
966
  async def upsert_edge(
967
+ self, source_node_id: str, target_node_id: str, edge_data: Dict[str, Any]
968
  ):
969
  """
970
  Upsert an edge and its properties between two nodes identified by their labels.
 
977
  source_node_label = source_node_id.strip('"')
978
  target_node_label = target_node_id.strip('"')
979
  edge_properties = edge_data
980
+ logger.info(
981
+ f"-- inserting edge: {source_node_label} -> {target_node_label}: {edge_data}"
982
+ )
983
 
984
  query = """MATCH (source:`{src_label}`)
985
  WITH source
 
1100
  }
1101
 
1102
 
 
1103
  SQL_TEMPLATES = {
1104
  # SQL for KVStorage
1105
  "get_by_id_full_docs": """SELECT id, COALESCE(content, '') as content
 
1150
  "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector)
1151
  VALUES ($1, $2, $3, $4, $5)
1152
  ON CONFLICT (workspace,id) DO UPDATE
1153
+ SET entity_name=EXCLUDED.entity_name,
1154
  content=EXCLUDED.content,
1155
  content_vector=EXCLUDED.content_vector,
1156
  updatetime=CURRENT_TIMESTAMP
 
1179
  (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
1180
  FROM LIGHTRAG_DOC_CHUNKS where workspace=$1)
1181
  WHERE distance>$2 ORDER BY distance DESC LIMIT $3
1182
+ """,
1183
  }
lightrag/kg/postgres_impl_test.py CHANGED
@@ -1,33 +1,39 @@
1
  import asyncio
2
  import asyncpg
3
- import sys, os
 
4
 
5
  import psycopg
6
  from psycopg_pool import AsyncConnectionPool
7
  from lightrag.kg.postgres_impl import PostgreSQLDB, PGGraphStorage
8
 
9
- DB="rag"
10
- USER="rag"
11
- PASSWORD="rag"
12
- HOST="localhost"
13
- PORT="15432"
14
  os.environ["AGE_GRAPH_NAME"] = "dickens"
15
 
16
  if sys.platform.startswith("win"):
17
  import asyncio.windows_events
 
18
  asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
19
 
 
20
  async def get_pool():
21
  return await asyncpg.create_pool(
22
  f"postgres://{USER}:{PASSWORD}@{HOST}:{PORT}/{DB}",
23
  min_size=10,
24
  max_size=10,
25
  max_queries=5000,
26
- max_inactive_connection_lifetime=300.0
27
  )
28
 
 
29
  async def main1():
30
- connection_string = f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}"
 
 
31
  pool = AsyncConnectionPool(connection_string, open=False)
32
  await pool.open()
33
 
@@ -36,18 +42,19 @@ async def main1():
36
  async with conn.cursor() as curs:
37
  try:
38
  await curs.execute('SET search_path = ag_catalog, "$user", public')
39
- await curs.execute(f"SELECT create_graph('dickens-2')")
40
  await conn.commit()
41
  print("create_graph success")
42
  except (
43
- psycopg.errors.InvalidSchemaName,
44
- psycopg.errors.UniqueViolation,
45
  ):
46
  print("create_graph already exists")
47
  await conn.rollback()
48
  finally:
49
  pass
50
 
 
51
  db = PostgreSQLDB(
52
  config={
53
  "host": "localhost",
@@ -58,6 +65,7 @@ db = PostgreSQLDB(
58
  }
59
  )
60
 
 
61
  async def query_with_age():
62
  await db.initdb()
63
  graph = PGGraphStorage(
@@ -69,6 +77,7 @@ async def query_with_age():
69
  res = await graph.get_node('"CHRISTMAS-TIME"')
70
  print("Node is: ", res)
71
 
 
72
  async def create_edge_with_age():
73
  await db.initdb()
74
  graph = PGGraphStorage(
@@ -89,31 +98,28 @@ async def create_edge_with_age():
89
  "source_id": "chunk-1d4b58de5429cd1261370c231c8673e8",
90
  },
91
  )
92
- res = await graph.get_edge('THE CRATCHITS', '"THE GIRLS"')
93
  print("Edge is: ", res)
94
 
95
 
96
  async def main():
97
  pool = await get_pool()
98
- # 如果还有其它什么特殊参数,也可以直接往里面传递,因为设置了 **connect_kwargs
99
- # 专门用来设置一些数据库独有的某些属性
100
- # 从池子中取出一个连接
101
  sql = r"SELECT * FROM ag_catalog.cypher('dickens', $$ MATCH (n:帅哥) RETURN n $$) AS (n ag_catalog.agtype)"
102
  # cypher = "MATCH (n:how_are_you_doing) RETURN n"
103
  async with pool.acquire() as conn:
104
- try:
105
- await conn.execute("""SET search_path = ag_catalog, "$user", public;select create_graph('dickens')""")
106
- except asyncpg.exceptions.InvalidSchemaNameError:
107
- print("create_graph already exists")
108
- # stmt = await conn.prepare(sql)
109
- row = await conn.fetch(sql)
110
- print("row is: ", row)
 
 
111
 
112
- # 解决办法就是起一个别名
113
- row = await conn.fetchrow("select '100'::int + 200 as result")
114
- print(row) # <Record result=300>
115
- # 我们的连接是从池子里面取出的,上下文结束之后会自动放回到到池子里面
116
 
117
 
118
- if __name__ == '__main__':
119
  asyncio.run(query_with_age())
 
1
  import asyncio
2
  import asyncpg
3
+ import sys
4
+ import os
5
 
6
  import psycopg
7
  from psycopg_pool import AsyncConnectionPool
8
  from lightrag.kg.postgres_impl import PostgreSQLDB, PGGraphStorage
9
 
10
+ DB = "rag"
11
+ USER = "rag"
12
+ PASSWORD = "rag"
13
+ HOST = "localhost"
14
+ PORT = "15432"
15
  os.environ["AGE_GRAPH_NAME"] = "dickens"
16
 
17
  if sys.platform.startswith("win"):
18
  import asyncio.windows_events
19
+
20
  asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
21
 
22
+
23
  async def get_pool():
24
  return await asyncpg.create_pool(
25
  f"postgres://{USER}:{PASSWORD}@{HOST}:{PORT}/{DB}",
26
  min_size=10,
27
  max_size=10,
28
  max_queries=5000,
29
+ max_inactive_connection_lifetime=300.0,
30
  )
31
 
32
+
33
  async def main1():
34
+ connection_string = (
35
+ f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}"
36
+ )
37
  pool = AsyncConnectionPool(connection_string, open=False)
38
  await pool.open()
39
 
 
42
  async with conn.cursor() as curs:
43
  try:
44
  await curs.execute('SET search_path = ag_catalog, "$user", public')
45
+ await curs.execute("SELECT create_graph('dickens-2')")
46
  await conn.commit()
47
  print("create_graph success")
48
  except (
49
+ psycopg.errors.InvalidSchemaName,
50
+ psycopg.errors.UniqueViolation,
51
  ):
52
  print("create_graph already exists")
53
  await conn.rollback()
54
  finally:
55
  pass
56
 
57
+
58
  db = PostgreSQLDB(
59
  config={
60
  "host": "localhost",
 
65
  }
66
  )
67
 
68
+
69
  async def query_with_age():
70
  await db.initdb()
71
  graph = PGGraphStorage(
 
77
  res = await graph.get_node('"CHRISTMAS-TIME"')
78
  print("Node is: ", res)
79
 
80
+
81
  async def create_edge_with_age():
82
  await db.initdb()
83
  graph = PGGraphStorage(
 
98
  "source_id": "chunk-1d4b58de5429cd1261370c231c8673e8",
99
  },
100
  )
101
+ res = await graph.get_edge("THE CRATCHITS", '"THE GIRLS"')
102
  print("Edge is: ", res)
103
 
104
 
105
  async def main():
106
  pool = await get_pool()
 
 
 
107
  sql = r"SELECT * FROM ag_catalog.cypher('dickens', $$ MATCH (n:帅哥) RETURN n $$) AS (n ag_catalog.agtype)"
108
  # cypher = "MATCH (n:how_are_you_doing) RETURN n"
109
  async with pool.acquire() as conn:
110
+ try:
111
+ await conn.execute(
112
+ """SET search_path = ag_catalog, "$user", public;select create_graph('dickens')"""
113
+ )
114
+ except asyncpg.exceptions.InvalidSchemaNameError:
115
+ print("create_graph already exists")
116
+ # stmt = await conn.prepare(sql)
117
+ row = await conn.fetch(sql)
118
+ print("row is: ", row)
119
 
120
+ row = await conn.fetchrow("select '100'::int + 200 as result")
121
+ print(row) # <Record result=300>
 
 
122
 
123
 
124
+ if __name__ == "__main__":
125
  asyncio.run(query_with_age())
requirements.txt CHANGED
@@ -1,6 +1,8 @@
1
  accelerate
2
  aioboto3~=13.3.0
 
3
  aiohttp~=3.11.11
 
4
 
5
  # database packages
6
  graspologic
@@ -9,14 +11,20 @@ hnswlib
9
  nano-vectordb
10
  neo4j~=5.27.0
11
  networkx~=3.2.1
 
 
12
  ollama~=0.4.4
13
  openai~=1.58.1
14
  oracledb
 
15
  psycopg[binary,pool]~=3.2.3
 
16
  pymilvus
17
  pymongo
18
  pymysql
 
19
  pyvis~=0.3.2
 
20
  # lmdeploy[all]
21
  sqlalchemy~=2.0.36
22
  tenacity~=9.0.0
@@ -25,14 +33,6 @@ tenacity~=9.0.0
25
  # LLM packages
26
  tiktoken~=0.8.0
27
  torch~=2.5.1+cu121
 
28
  transformers~=4.47.1
29
  xxhash
30
-
31
- numpy~=2.2.0
32
- aiofiles~=24.1.0
33
- pydantic~=2.10.4
34
- python-dotenv~=1.0.1
35
- psycopg-pool~=3.2.4
36
- tqdm~=4.67.1
37
- asyncpg~=0.30.0
38
- setuptools~=70.0.0
 
1
  accelerate
2
  aioboto3~=13.3.0
3
+ aiofiles~=24.1.0
4
  aiohttp~=3.11.11
5
+ asyncpg~=0.30.0
6
 
7
  # database packages
8
  graspologic
 
11
  nano-vectordb
12
  neo4j~=5.27.0
13
  networkx~=3.2.1
14
+
15
+ numpy~=2.2.0
16
  ollama~=0.4.4
17
  openai~=1.58.1
18
  oracledb
19
+ psycopg-pool~=3.2.4
20
  psycopg[binary,pool]~=3.2.3
21
+ pydantic~=2.10.4
22
  pymilvus
23
  pymongo
24
  pymysql
25
+ python-dotenv~=1.0.1
26
  pyvis~=0.3.2
27
+ setuptools~=70.0.0
28
  # lmdeploy[all]
29
  sqlalchemy~=2.0.36
30
  tenacity~=9.0.0
 
33
  # LLM packages
34
  tiktoken~=0.8.0
35
  torch~=2.5.1+cu121
36
+ tqdm~=4.67.1
37
  transformers~=4.47.1
38
  xxhash