YanSte commited on
Commit
f8dec77
·
unverified ·
2 Parent(s): 82dbfbd 97a0a3f

Merge pull request #869 from YanSte/back-age

Browse files
Files changed (1) hide show
  1. lightrag/kg/postgres_impl.py +41 -21
lightrag/kg/postgres_impl.py CHANGED
@@ -86,29 +86,40 @@ class PostgreSQLDB:
86
  )
87
  raise
88
 
89
- async def check_graph_requirement(self, graph_name: str):
90
- async with self.pool.acquire() as connection: # type: ignore
91
- try:
92
- await connection.execute(
93
- 'SET search_path = ag_catalog, "$user", public'
94
- ) # type: ignore
95
- await connection.execute(f"select create_graph('{graph_name}')") # type: ignore
96
- except (
97
- asyncpg.exceptions.InvalidSchemaNameError,
98
- asyncpg.exceptions.UniqueViolationError,
99
- ):
100
- pass
 
 
 
 
 
 
 
 
 
 
101
 
102
  async def check_tables(self):
103
  for k, v in TABLES.items():
104
  try:
105
  await self.query(f"SELECT 1 FROM {k} LIMIT 1")
106
- except Exception as e:
107
- logger.error(f"PostgreSQL database error: {e}")
108
  try:
109
  logger.info(f"PostgreSQL, Try Creating table {k} in database")
110
  await self.execute(v["ddl"])
111
- logger.info(f"PostgreSQL, Created table {k} in PostgreSQL database")
 
 
112
  except Exception as e:
113
  logger.error(
114
  f"PostgreSQL, Failed to create table {k} in database, Please verify the connection with PostgreSQL database, Got: {e}"
@@ -120,8 +131,15 @@ class PostgreSQLDB:
120
  sql: str,
121
  params: dict[str, Any] | None = None,
122
  multirows: bool = False,
 
 
123
  ) -> dict[str, Any] | None | list[dict[str, Any]]:
124
  async with self.pool.acquire() as connection: # type: ignore
 
 
 
 
 
125
  try:
126
  if params:
127
  rows = await connection.fetch(sql, *params.values())
@@ -142,9 +160,7 @@ class PostgreSQLDB:
142
  data = None
143
  return data
144
  except Exception as e:
145
- logger.error(
146
- f"PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
147
- )
148
  raise
149
 
150
  async def execute(
@@ -152,6 +168,8 @@ class PostgreSQLDB:
152
  sql: str,
153
  data: dict[str, Any] | None = None,
154
  upsert: bool = False,
 
 
155
  ):
156
  try:
157
  async with self.pool.acquire() as connection: # type: ignore
@@ -656,9 +674,6 @@ class PGGraphStorage(BaseGraphStorage):
656
  async def initialize(self):
657
  if self.db is None:
658
  self.db = await ClientManager.get_client()
659
- # `check_graph_requirement` is required to be executed after `get_client`
660
- # to ensure the graph is created before any query is executed.
661
- await self.db.check_graph_requirement(self.graph_name)
662
 
663
  async def finalize(self):
664
  if self.db is not None:
@@ -829,12 +844,17 @@ class PGGraphStorage(BaseGraphStorage):
829
  data = await self.db.query(
830
  query,
831
  multirows=True,
 
 
832
  )
833
  else:
834
  data = await self.db.execute(
835
  query,
836
  upsert=upsert,
 
 
837
  )
 
838
  except Exception as e:
839
  raise PGGraphQueryException(
840
  {
 
86
  )
87
  raise
88
 
89
+ @staticmethod
90
+ async def configure_age(connection: asyncpg.Connection, graph_name: str) -> None:
91
+ """Set the Apache AGE environment and creates a graph if it does not exist.
92
+
93
+ This method:
94
+ - Sets the PostgreSQL `search_path` to include `ag_catalog`, ensuring that Apache AGE functions can be used without specifying the schema.
95
+ - Attempts to create a new graph with the provided `graph_name` if it does not already exist.
96
+ - Silently ignores errors related to the graph already existing.
97
+
98
+ """
99
+ try:
100
+ await connection.execute( # type: ignore
101
+ 'SET search_path = ag_catalog, "$user", public'
102
+ )
103
+ await connection.execute( # type: ignore
104
+ f"select create_graph('{graph_name}')"
105
+ )
106
+ except (
107
+ asyncpg.exceptions.InvalidSchemaNameError,
108
+ asyncpg.exceptions.UniqueViolationError,
109
+ ):
110
+ pass
111
 
112
  async def check_tables(self):
113
  for k, v in TABLES.items():
114
  try:
115
  await self.query(f"SELECT 1 FROM {k} LIMIT 1")
116
+ except Exception:
 
117
  try:
118
  logger.info(f"PostgreSQL, Try Creating table {k} in database")
119
  await self.execute(v["ddl"])
120
+ logger.info(
121
+ f"PostgreSQL, Creation success table {k} in PostgreSQL database"
122
+ )
123
  except Exception as e:
124
  logger.error(
125
  f"PostgreSQL, Failed to create table {k} in database, Please verify the connection with PostgreSQL database, Got: {e}"
 
131
  sql: str,
132
  params: dict[str, Any] | None = None,
133
  multirows: bool = False,
134
+ with_age: bool = False,
135
+ graph_name: str | None = None,
136
  ) -> dict[str, Any] | None | list[dict[str, Any]]:
137
  async with self.pool.acquire() as connection: # type: ignore
138
+ if with_age and graph_name:
139
+ await self.configure_age(connection, graph_name) # type: ignore
140
+ elif with_age and not graph_name:
141
+ raise ValueError("Graph name is required when with_age is True")
142
+
143
  try:
144
  if params:
145
  rows = await connection.fetch(sql, *params.values())
 
160
  data = None
161
  return data
162
  except Exception as e:
163
+ logger.error(f"PostgreSQL database, error:{e}")
 
 
164
  raise
165
 
166
  async def execute(
 
168
  sql: str,
169
  data: dict[str, Any] | None = None,
170
  upsert: bool = False,
171
+ with_age: bool = False,
172
+ graph_name: str | None = None,
173
  ):
174
  try:
175
  async with self.pool.acquire() as connection: # type: ignore
 
674
  async def initialize(self):
675
  if self.db is None:
676
  self.db = await ClientManager.get_client()
 
 
 
677
 
678
  async def finalize(self):
679
  if self.db is not None:
 
844
  data = await self.db.query(
845
  query,
846
  multirows=True,
847
+ with_age=True,
848
+ graph_name=self.graph_name,
849
  )
850
  else:
851
  data = await self.db.execute(
852
  query,
853
  upsert=upsert,
854
+ with_age=True,
855
+ graph_name=self.graph_name,
856
  )
857
+
858
  except Exception as e:
859
  raise PGGraphQueryException(
860
  {