samuel-z-chen committed on
Commit
5876f53
·
1 Parent(s): 359e407

Fix the bug of AGE processing

Browse files
examples/lightrag_zhipu_postgres_demo.py CHANGED
@@ -6,7 +6,7 @@ import time
6
  from dotenv import load_dotenv
7
 
8
  from lightrag import LightRAG, QueryParam
9
- from lightrag.kg.postgres_impl import PostgreSQLDB, PGGraphStorage
10
  from lightrag.llm import ollama_embedding, zhipu_complete
11
  from lightrag.utils import EmbeddingFunc
12
 
@@ -67,7 +67,6 @@ async def main():
67
  rag.entities_vdb.db = postgres_db
68
  rag.graph_storage_cls.db = postgres_db
69
  rag.chunk_entity_relation_graph.db = postgres_db
70
- await rag.chunk_entity_relation_graph.check_graph_exists()
71
  # add embedding_func for graph database, it's deleted in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c
72
  rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func
73
 
@@ -103,21 +102,6 @@ async def main():
103
  )
104
  print(f"Hybrid Query Time: {time.time() - start_time} seconds")
105
 
106
- print("**** Start Stream Query ****")
107
- start_time = time.time()
108
- # stream response
109
- resp = await rag.aquery(
110
- "What are the top themes in this story?",
111
- param=QueryParam(mode="hybrid", stream=True),
112
- )
113
- print(f"Stream Query Time: {time.time() - start_time} seconds")
114
- print("**** Done Stream Query ****")
115
-
116
- if inspect.isasyncgen(resp):
117
- asyncio.run(print_stream(resp))
118
- else:
119
- print(resp)
120
-
121
 
122
  if __name__ == "__main__":
123
  asyncio.run(main())
 
6
  from dotenv import load_dotenv
7
 
8
  from lightrag import LightRAG, QueryParam
9
+ from lightrag.kg.postgres_impl import PostgreSQLDB
10
  from lightrag.llm import ollama_embedding, zhipu_complete
11
  from lightrag.utils import EmbeddingFunc
12
 
 
67
  rag.entities_vdb.db = postgres_db
68
  rag.graph_storage_cls.db = postgres_db
69
  rag.chunk_entity_relation_graph.db = postgres_db
 
70
  # add embedding_func for graph database, it's deleted in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c
71
  rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func
72
 
 
102
  )
103
  print(f"Hybrid Query Time: {time.time() - start_time} seconds")
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  if __name__ == "__main__":
107
  asyncio.run(main())
lightrag/kg/postgres_impl.py CHANGED
@@ -81,12 +81,12 @@ class PostgreSQLDB:
81
 
82
 
83
  async def query(
84
- self, sql: str, params: dict = None, multirows: bool = False, for_age: bool = False
85
  ) -> Union[dict, None, list[dict]]:
86
  async with self.pool.acquire() as connection:
87
  try:
88
  if for_age:
89
- await connection.execute('SET search_path = ag_catalog, "$user", public')
90
  if params:
91
  rows = await connection.fetch(sql, *params.values())
92
  else:
@@ -95,10 +95,7 @@ class PostgreSQLDB:
95
  if multirows:
96
  if rows:
97
  columns = [col for col in rows[0].keys()]
98
- # print("columns", columns.__class__, columns)
99
- # print("rows", rows)
100
  data = [dict(zip(columns, row)) for row in rows]
101
- # print("data", data)
102
  else:
103
  data = []
104
  else:
@@ -114,11 +111,11 @@ class PostgreSQLDB:
114
  print(params)
115
  raise
116
 
117
- async def execute(self, sql: str, data: Union[list, dict] = None, for_age: bool = False):
118
  try:
119
  async with self.pool.acquire() as connection:
120
  if for_age:
121
- await connection.execute('SET search_path = ag_catalog, "$user", public')
122
 
123
  if data is None:
124
  await connection.execute(sql)
@@ -130,6 +127,14 @@ class PostgreSQLDB:
130
  print(data)
131
  raise
132
 
 
 
 
 
 
 
 
 
133
 
134
  @dataclass
135
  class PGKVStorage(BaseKVStorage):
@@ -346,18 +351,14 @@ class PGVectorStorage(BaseVectorStorage):
346
  embeddings = await self.embedding_func([query])
347
  embedding = embeddings[0]
348
  embedding_string = ",".join(map(str, embedding))
349
- # print("Namespace", self.namespace)
350
 
351
  sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
352
- # print("sql is: ", sql)
353
  params = {
354
  "workspace": self.db.workspace,
355
  "better_than_threshold": self.cosine_better_than_threshold,
356
  "top_k": top_k,
357
  }
358
- # print("params is: ", params)
359
  results = await self.db.query(sql, params=params, multirows=True)
360
- print("vector search result:", results)
361
  return results
362
 
363
  @dataclass
@@ -487,19 +488,6 @@ class PGGraphStorage(BaseGraphStorage):
487
  async def index_done_callback(self):
488
  print("KG successfully indexed.")
489
 
490
- async def check_graph_exists(self):
491
- try:
492
- res = await self.db.query(f"SELECT * FROM ag_catalog.ag_graph WHERE name = '{self.graph_name}'")
493
- if res:
494
- logger.info(f"Graph {self.graph_name} exists.")
495
- else:
496
- logger.info(f"Graph {self.graph_name} does not exist. Creating...")
497
- await self.db.execute(f"SELECT create_graph('{self.graph_name}')", for_age=True)
498
- logger.info(f"Graph {self.graph_name} created.")
499
- except Exception as e:
500
- logger.info(f"Failed to check/create graph {self.graph_name}:", e)
501
- raise e
502
-
503
  @staticmethod
504
  def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]:
505
  """
@@ -572,7 +560,7 @@ class PGGraphStorage(BaseGraphStorage):
572
 
573
  Args:
574
  properties (Dict[str,str]): a dictionary containing node/edge properties
575
- id (Union[str, None]): the id of the node or None if none exists
576
 
577
  Returns:
578
  str: the properties dictionary as a properly formatted string
@@ -591,7 +579,7 @@ class PGGraphStorage(BaseGraphStorage):
591
  @staticmethod
592
  def _encode_graph_label(label: str) -> str:
593
  """
594
- Since AGE suports only alphanumerical labels, we will encode generic label as HEX string
595
 
596
  Args:
597
  label (str): the original label
@@ -604,7 +592,7 @@ class PGGraphStorage(BaseGraphStorage):
604
  @staticmethod
605
  def _decode_graph_label(encoded_label: str) -> str:
606
  """
607
- Since AGE suports only alphanumerical labels, we will encode generic label as HEX string
608
 
609
  Args:
610
  encoded_label (str): the encoded label
@@ -656,8 +644,8 @@ class PGGraphStorage(BaseGraphStorage):
656
 
657
  # pgsql template
658
  template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$
659
- {query}
660
- $$) AS ({fields});"""
661
 
662
  # if there are any returned fields they must be added to the pgsql query
663
  if "return" in query.lower():
@@ -702,7 +690,7 @@ class PGGraphStorage(BaseGraphStorage):
702
  projection=select_str,
703
  )
704
 
705
- async def _query(self, query: str, readonly=True, **params: str) -> List[Dict[str, Any]]:
706
  """
707
  Query the graph by taking a cypher query, converting it to an
708
  age compatible query, executing it and converting the result
@@ -720,9 +708,14 @@ class PGGraphStorage(BaseGraphStorage):
720
  # execute the query, rolling back on an error
721
  try:
722
  if readonly:
723
- data = await self.db.query(wrapped_query, multirows=True, for_age=True)
724
  else:
725
- data = await self.db.execute(wrapped_query, for_age=True)
 
 
 
 
 
726
  except Exception as e:
727
  raise PGGraphQueryException(
728
  {
@@ -743,9 +736,7 @@ class PGGraphStorage(BaseGraphStorage):
743
  async def has_node(self, node_id: str) -> bool:
744
  entity_name_label = node_id.strip('"')
745
 
746
- query = """
747
- MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists
748
- """
749
  params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
750
  single_result = (await self._query(query, **params))[0]
751
  logger.debug(
@@ -761,10 +752,8 @@ class PGGraphStorage(BaseGraphStorage):
761
  entity_name_label_source = source_node_id.strip('"')
762
  entity_name_label_target = target_node_id.strip('"')
763
 
764
- query = """
765
- MATCH (a:`{src_label}`)-[r]-(b:`{tgt_label}`)
766
- RETURN COUNT(r) > 0 AS edge_exists
767
- """
768
  params = {
769
  "src_label": PGGraphStorage._encode_graph_label(entity_name_label_source),
770
  "tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target),
@@ -780,9 +769,7 @@ class PGGraphStorage(BaseGraphStorage):
780
 
781
  async def get_node(self, node_id: str) -> Union[dict, None]:
782
  entity_name_label = node_id.strip('"')
783
- query = """
784
- MATCH (n:`{label}`) RETURN n
785
- """
786
  params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
787
  record = await self._query(query, **params)
788
  if record:
@@ -800,10 +787,7 @@ class PGGraphStorage(BaseGraphStorage):
800
  async def node_degree(self, node_id: str) -> int:
801
  entity_name_label = node_id.strip('"')
802
 
803
- query = """
804
- MATCH (n:`{label}`)-[]->(x)
805
- RETURN count(x) AS total_edge_count
806
- """
807
  params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
808
  record = (await self._query(query, **params))[0]
809
  if record:
@@ -841,8 +825,8 @@ class PGGraphStorage(BaseGraphStorage):
841
  Find all edges between nodes of two given labels
842
 
843
  Args:
844
- source_node_label (str): Label of the source nodes
845
- target_node_label (str): Label of the target nodes
846
 
847
  Returns:
848
  list: List of all relationships/edges found
@@ -850,11 +834,9 @@ class PGGraphStorage(BaseGraphStorage):
850
  entity_name_label_source = source_node_id.strip('"')
851
  entity_name_label_target = target_node_id.strip('"')
852
 
853
- query = """
854
- MATCH (a:`{src_label}`)-[r]->(b:`{tgt_label}`)
855
  RETURN properties(r) as edge_properties
856
- LIMIT 1
857
- """
858
  params = {
859
  "src_label": PGGraphStorage._encode_graph_label(entity_name_label_source),
860
  "tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target),
@@ -877,11 +859,9 @@ class PGGraphStorage(BaseGraphStorage):
877
  """
878
  node_label = source_node_id.strip('"')
879
 
880
- query = """
881
- MATCH (n:`{label}`)
882
  OPTIONAL MATCH (n)-[r]-(connected)
883
- RETURN n, r, connected
884
- """
885
  params = {"label": PGGraphStorage._encode_graph_label(node_label)}
886
  results = await self._query(query, **params)
887
  edges = []
@@ -919,10 +899,8 @@ class PGGraphStorage(BaseGraphStorage):
919
  label = node_id.strip('"')
920
  properties = node_data
921
 
922
- query = """
923
- MERGE (n:`{label}`)
924
- SET n += {properties}
925
- """
926
  params = {
927
  "label": PGGraphStorage._encode_graph_label(label),
928
  "properties": PGGraphStorage._format_properties(properties),
@@ -957,22 +935,22 @@ class PGGraphStorage(BaseGraphStorage):
957
  source_node_label = source_node_id.strip('"')
958
  target_node_label = target_node_id.strip('"')
959
  edge_properties = edge_data
 
960
 
961
- query = """
962
- MATCH (source:`{src_label}`)
963
  WITH source
964
  MATCH (target:`{tgt_label}`)
965
  MERGE (source)-[r:DIRECTED]->(target)
966
  SET r += {properties}
967
- RETURN r
968
- """
969
  params = {
970
  "src_label": PGGraphStorage._encode_graph_label(source_node_label),
971
  "tgt_label": PGGraphStorage._encode_graph_label(target_node_label),
972
  "properties": PGGraphStorage._format_properties(edge_properties),
973
  }
 
974
  try:
975
- await self._query(query, readonly=False, **params)
976
  logger.debug(
977
  "Upserted edge from '{%s}' to '{%s}' with properties: {%s}",
978
  source_node_label,
@@ -1127,7 +1105,7 @@ SQL_TEMPLATES = {
1127
  updatetime = CURRENT_TIMESTAMP
1128
  """,
1129
  "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector)
1130
- VALUES ($1, $2, $3, $4, $5, $6)
1131
  ON CONFLICT (workspace,id) DO UPDATE
1132
  SET entity_name=EXCLUDED.entity_name,
1133
  content=EXCLUDED.content,
 
81
 
82
 
83
  async def query(
84
+ self, sql: str, params: dict = None, multirows: bool = False, for_age: bool = False, graph_name: str = None
85
  ) -> Union[dict, None, list[dict]]:
86
  async with self.pool.acquire() as connection:
87
  try:
88
  if for_age:
89
+ await PostgreSQLDB._prerequisite(connection, graph_name)
90
  if params:
91
  rows = await connection.fetch(sql, *params.values())
92
  else:
 
95
  if multirows:
96
  if rows:
97
  columns = [col for col in rows[0].keys()]
 
 
98
  data = [dict(zip(columns, row)) for row in rows]
 
99
  else:
100
  data = []
101
  else:
 
111
  print(params)
112
  raise
113
 
114
+ async def execute(self, sql: str, data: Union[list, dict] = None, for_age: bool = False, graph_name: str = None):
115
  try:
116
  async with self.pool.acquire() as connection:
117
  if for_age:
118
+ await PostgreSQLDB._prerequisite(connection, graph_name)
119
 
120
  if data is None:
121
  await connection.execute(sql)
 
127
  print(data)
128
  raise
129
 
130
+ @staticmethod
131
+ async def _prerequisite(conn: asyncpg.Connection, graph_name: str):
132
+ try:
133
+ await conn.execute(f'SET search_path = ag_catalog, "$user", public')
134
+ await conn.execute(f"""select create_graph('{graph_name}')""")
135
+ except asyncpg.exceptions.InvalidSchemaNameError:
136
+ pass
137
+
138
 
139
  @dataclass
140
  class PGKVStorage(BaseKVStorage):
 
351
  embeddings = await self.embedding_func([query])
352
  embedding = embeddings[0]
353
  embedding_string = ",".join(map(str, embedding))
 
354
 
355
  sql = SQL_TEMPLATES[self.namespace].format(embedding_string=embedding_string)
 
356
  params = {
357
  "workspace": self.db.workspace,
358
  "better_than_threshold": self.cosine_better_than_threshold,
359
  "top_k": top_k,
360
  }
 
361
  results = await self.db.query(sql, params=params, multirows=True)
 
362
  return results
363
 
364
  @dataclass
 
488
  async def index_done_callback(self):
489
  print("KG successfully indexed.")
490
 
 
 
 
 
 
 
 
 
 
 
 
 
 
491
  @staticmethod
492
  def _record_to_dict(record: asyncpg.Record) -> Dict[str, Any]:
493
  """
 
560
 
561
  Args:
562
  properties (Dict[str,str]): a dictionary containing node/edge properties
563
+ _id (Union[str, None]): the id of the node or None if none exists
564
 
565
  Returns:
566
  str: the properties dictionary as a properly formatted string
 
579
  @staticmethod
580
  def _encode_graph_label(label: str) -> str:
581
  """
582
+ Since AGE supports only alphanumerical labels, we will encode generic label as HEX string
583
 
584
  Args:
585
  label (str): the original label
 
592
  @staticmethod
593
  def _decode_graph_label(encoded_label: str) -> str:
594
  """
595
+ Since AGE supports only alphanumerical labels, we will encode generic label as HEX string
596
 
597
  Args:
598
  encoded_label (str): the encoded label
 
644
 
645
  # pgsql template
646
  template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$
647
+ {query}
648
+ $$) AS ({fields})"""
649
 
650
  # if there are any returned fields they must be added to the pgsql query
651
  if "return" in query.lower():
 
690
  projection=select_str,
691
  )
692
 
693
+ async def _query(self, query: str, readonly=True, upsert_edge=False, **params: str) -> List[Dict[str, Any]]:
694
  """
695
  Query the graph by taking a cypher query, converting it to an
696
  age compatible query, executing it and converting the result
 
708
  # execute the query, rolling back on an error
709
  try:
710
  if readonly:
711
+ data = await self.db.query(wrapped_query, multirows=True, for_age=True, graph_name=self.graph_name)
712
  else:
713
+ # for upserting edge, need to run the SQL twice, otherwise cannot update the properties. (First time it will try to create the edge, second time is MERGING)
714
+ # It is a bug of AGE as of 2025-01-03, hope it can be resolved in the future.
715
+ if upsert_edge:
716
+ data = await self.db.execute(f"{wrapped_query};{wrapped_query};", for_age=True, graph_name=self.graph_name)
717
+ else:
718
+ data = await self.db.execute(wrapped_query, for_age=True, graph_name=self.graph_name)
719
  except Exception as e:
720
  raise PGGraphQueryException(
721
  {
 
736
  async def has_node(self, node_id: str) -> bool:
737
  entity_name_label = node_id.strip('"')
738
 
739
+ query = """MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists"""
 
 
740
  params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
741
  single_result = (await self._query(query, **params))[0]
742
  logger.debug(
 
752
  entity_name_label_source = source_node_id.strip('"')
753
  entity_name_label_target = target_node_id.strip('"')
754
 
755
+ query = """MATCH (a:`{src_label}`)-[r]-(b:`{tgt_label}`)
756
+ RETURN COUNT(r) > 0 AS edge_exists"""
 
 
757
  params = {
758
  "src_label": PGGraphStorage._encode_graph_label(entity_name_label_source),
759
  "tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target),
 
769
 
770
  async def get_node(self, node_id: str) -> Union[dict, None]:
771
  entity_name_label = node_id.strip('"')
772
+ query = """MATCH (n:`{label}`) RETURN n"""
 
 
773
  params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
774
  record = await self._query(query, **params)
775
  if record:
 
787
  async def node_degree(self, node_id: str) -> int:
788
  entity_name_label = node_id.strip('"')
789
 
790
+ query = """MATCH (n:`{label}`)-[]->(x) RETURN count(x) AS total_edge_count"""
 
 
 
791
  params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
792
  record = (await self._query(query, **params))[0]
793
  if record:
 
825
  Find all edges between nodes of two given labels
826
 
827
  Args:
828
+ source_node_id (str): Label of the source nodes
829
+ target_node_id (str): Label of the target nodes
830
 
831
  Returns:
832
  list: List of all relationships/edges found
 
834
  entity_name_label_source = source_node_id.strip('"')
835
  entity_name_label_target = target_node_id.strip('"')
836
 
837
+ query = """MATCH (a:`{src_label}`)-[r]->(b:`{tgt_label}`)
 
838
  RETURN properties(r) as edge_properties
839
+ LIMIT 1"""
 
840
  params = {
841
  "src_label": PGGraphStorage._encode_graph_label(entity_name_label_source),
842
  "tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target),
 
859
  """
860
  node_label = source_node_id.strip('"')
861
 
862
+ query = """MATCH (n:`{label}`)
 
863
  OPTIONAL MATCH (n)-[r]-(connected)
864
+ RETURN n, r, connected"""
 
865
  params = {"label": PGGraphStorage._encode_graph_label(node_label)}
866
  results = await self._query(query, **params)
867
  edges = []
 
899
  label = node_id.strip('"')
900
  properties = node_data
901
 
902
+ query = """MERGE (n:`{label}`)
903
+ SET n += {properties}"""
 
 
904
  params = {
905
  "label": PGGraphStorage._encode_graph_label(label),
906
  "properties": PGGraphStorage._format_properties(properties),
 
935
  source_node_label = source_node_id.strip('"')
936
  target_node_label = target_node_id.strip('"')
937
  edge_properties = edge_data
938
+ logger.info(f"-- inserting edge: {source_node_label} -> {target_node_label}: {edge_data}")
939
 
940
+ query = """MATCH (source:`{src_label}`)
 
941
  WITH source
942
  MATCH (target:`{tgt_label}`)
943
  MERGE (source)-[r:DIRECTED]->(target)
944
  SET r += {properties}
945
+ RETURN r"""
 
946
  params = {
947
  "src_label": PGGraphStorage._encode_graph_label(source_node_label),
948
  "tgt_label": PGGraphStorage._encode_graph_label(target_node_label),
949
  "properties": PGGraphStorage._format_properties(edge_properties),
950
  }
951
+ # logger.info(f"-- inserting edge after formatted: {params}")
952
  try:
953
+ await self._query(query, readonly=False, upsert_edge=True, **params)
954
  logger.debug(
955
  "Upserted edge from '{%s}' to '{%s}' with properties: {%s}",
956
  source_node_label,
 
1105
  updatetime = CURRENT_TIMESTAMP
1106
  """,
1107
  "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector)
1108
+ VALUES ($1, $2, $3, $4, $5)
1109
  ON CONFLICT (workspace,id) DO UPDATE
1110
  SET entity_name=EXCLUDED.entity_name,
1111
  content=EXCLUDED.content,
lightrag/kg/postgres_impl_test.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import asyncpg
3
+ import sys, os
4
+
5
+ import psycopg
6
+ from psycopg_pool import AsyncConnectionPool
7
+ from lightrag.kg.postgres_impl import PostgreSQLDB, PGGraphStorage
8
+
9
+ DB="rag"
10
+ USER="rag"
11
+ PASSWORD="rag"
12
+ HOST="localhost"
13
+ PORT="15432"
14
+ os.environ["AGE_GRAPH_NAME"] = "dickens"
15
+
16
+ if sys.platform.startswith("win"):
17
+ import asyncio.windows_events
18
+ asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
19
+
20
+ async def get_pool():
21
+ return await asyncpg.create_pool(
22
+ f"postgres://{USER}:{PASSWORD}@{HOST}:{PORT}/{DB}",
23
+ min_size=10, # 连接池初始化时默认的最小连接数, 默认为1 0
24
+ max_size=10, # 连接池的最大连接数, 默认为 10
25
+ max_queries=5000, # 每个链接最大查询数量, 超过了就换新的连接, 默认 5000
26
+ # 最大不活跃时间, 默认 300.0, 超过这个时间的连接就会被关闭, 传入 0 的话则永不关闭
27
+ max_inactive_connection_lifetime=300.0
28
+ )
29
+
30
+ async def main1():
31
+ connection_string = f"dbname='{DB}' user='{USER}' password='{PASSWORD}' host='{HOST}' port={PORT}"
32
+ pool = AsyncConnectionPool(connection_string, open=False)
33
+ await pool.open()
34
+
35
+ try:
36
+ conn = await pool.getconn(timeout=10)
37
+ async with conn.cursor() as curs:
38
+ try:
39
+ await curs.execute('SET search_path = ag_catalog, "$user", public')
40
+ await curs.execute(f"SELECT create_graph('dickens-2')")
41
+ await conn.commit()
42
+ print("create_graph success")
43
+ except (
44
+ psycopg.errors.InvalidSchemaName,
45
+ psycopg.errors.UniqueViolation,
46
+ ):
47
+ print("create_graph already exists")
48
+ await conn.rollback()
49
+ finally:
50
+ pass
51
+
52
+ db = PostgreSQLDB(
53
+ config={
54
+ "host": "localhost",
55
+ "port": 15432,
56
+ "user": "rag",
57
+ "password": "rag",
58
+ "database": "rag",
59
+ }
60
+ )
61
+
62
+ async def query_with_age():
63
+ await db.initdb()
64
+ graph = PGGraphStorage(
65
+ namespace="chunk_entity_relation",
66
+ global_config={},
67
+ embedding_func=None,
68
+ )
69
+ graph.db = db
70
+ res = await graph.get_node('"CHRISTMAS-TIME"')
71
+ print("Node is: ", res)
72
+
73
+ async def create_edge_with_age():
74
+ await db.initdb()
75
+ graph = PGGraphStorage(
76
+ namespace="chunk_entity_relation",
77
+ global_config={},
78
+ embedding_func=None,
79
+ )
80
+ graph.db = db
81
+ await graph.upsert_node('"THE CRATCHITS"', {"hello": "world"})
82
+ await graph.upsert_node('"THE GIRLS"', {"world": "hello"})
83
+ await graph.upsert_edge(
84
+ '"THE CRATCHITS"',
85
+ '"THE GIRLS"',
86
+ edge_data={
87
+ "weight": 7.0,
88
+ "description": '"The girls are part of the Cratchit family, contributing to their collective efforts and shared experiences.',
89
+ "keywords": '"family, collective effort"',
90
+ "source_id": "chunk-1d4b58de5429cd1261370c231c8673e8",
91
+ },
92
+ )
93
+ res = await graph.get_edge('THE CRATCHITS', '"THE GIRLS"')
94
+ print("Edge is: ", res)
95
+
96
+
97
+ async def main():
98
+ pool = await get_pool()
99
+ # 如果还有其它什么特殊参数,也可以直接往里面传递,因为设置了 **connect_kwargs
100
+ # 专门用来设置一些数据库独有的某些属性
101
+ # 从池子中取出一个连接
102
+ sql = r"SELECT * FROM ag_catalog.cypher('dickens', $$ MATCH (n:帅哥) RETURN n $$) AS (n ag_catalog.agtype)"
103
+ # cypher = "MATCH (n:how_are_you_doing) RETURN n"
104
+ async with pool.acquire() as conn:
105
+ try:
106
+ await conn.execute("""SET search_path = ag_catalog, "$user", public;select create_graph('dickens')""")
107
+ except asyncpg.exceptions.InvalidSchemaNameError:
108
+ print("create_graph already exists")
109
+ # stmt = await conn.prepare(sql)
110
+ row = await conn.fetch(sql)
111
+ print("row is: ", row)
112
+
113
+ # 解决办法就是起一个别名
114
+ row = await conn.fetchrow("select '100'::int + 200 as result")
115
+ print(row) # <Record result=300>
116
+ # 我们的连接是从池子里面取出的,上下文结束之后会自动放回到到池子里面
117
+
118
+
119
+ if __name__ == '__main__':
120
+ asyncio.run(query_with_age())
121
+
122
+