Merge pull request #869 from YanSte/back-age
Browse files- lightrag/kg/postgres_impl.py +41 -21
lightrag/kg/postgres_impl.py
CHANGED
@@ -86,29 +86,40 @@ class PostgreSQLDB:
|
|
86 |
)
|
87 |
raise
|
88 |
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
|
102 |
async def check_tables(self):
|
103 |
for k, v in TABLES.items():
|
104 |
try:
|
105 |
await self.query(f"SELECT 1 FROM {k} LIMIT 1")
|
106 |
-
except Exception
|
107 |
-
logger.error(f"PostgreSQL database error: {e}")
|
108 |
try:
|
109 |
logger.info(f"PostgreSQL, Try Creating table {k} in database")
|
110 |
await self.execute(v["ddl"])
|
111 |
-
logger.info(
|
|
|
|
|
112 |
except Exception as e:
|
113 |
logger.error(
|
114 |
f"PostgreSQL, Failed to create table {k} in database, Please verify the connection with PostgreSQL database, Got: {e}"
|
@@ -120,8 +131,15 @@ class PostgreSQLDB:
|
|
120 |
sql: str,
|
121 |
params: dict[str, Any] | None = None,
|
122 |
multirows: bool = False,
|
|
|
|
|
123 |
) -> dict[str, Any] | None | list[dict[str, Any]]:
|
124 |
async with self.pool.acquire() as connection: # type: ignore
|
|
|
|
|
|
|
|
|
|
|
125 |
try:
|
126 |
if params:
|
127 |
rows = await connection.fetch(sql, *params.values())
|
@@ -142,9 +160,7 @@ class PostgreSQLDB:
|
|
142 |
data = None
|
143 |
return data
|
144 |
except Exception as e:
|
145 |
-
logger.error(
|
146 |
-
f"PostgreSQL database,\nsql:{sql},\nparams:{params},\nerror:{e}"
|
147 |
-
)
|
148 |
raise
|
149 |
|
150 |
async def execute(
|
@@ -152,6 +168,8 @@ class PostgreSQLDB:
|
|
152 |
sql: str,
|
153 |
data: dict[str, Any] | None = None,
|
154 |
upsert: bool = False,
|
|
|
|
|
155 |
):
|
156 |
try:
|
157 |
async with self.pool.acquire() as connection: # type: ignore
|
@@ -656,9 +674,6 @@ class PGGraphStorage(BaseGraphStorage):
|
|
656 |
async def initialize(self):
|
657 |
if self.db is None:
|
658 |
self.db = await ClientManager.get_client()
|
659 |
-
# `check_graph_requirement` is required to be executed after `get_client`
|
660 |
-
# to ensure the graph is created before any query is executed.
|
661 |
-
await self.db.check_graph_requirement(self.graph_name)
|
662 |
|
663 |
async def finalize(self):
|
664 |
if self.db is not None:
|
@@ -829,12 +844,17 @@ class PGGraphStorage(BaseGraphStorage):
|
|
829 |
data = await self.db.query(
|
830 |
query,
|
831 |
multirows=True,
|
|
|
|
|
832 |
)
|
833 |
else:
|
834 |
data = await self.db.execute(
|
835 |
query,
|
836 |
upsert=upsert,
|
|
|
|
|
837 |
)
|
|
|
838 |
except Exception as e:
|
839 |
raise PGGraphQueryException(
|
840 |
{
|
|
|
86 |
)
|
87 |
raise
|
88 |
|
89 |
+
@staticmethod
|
90 |
+
async def configure_age(connection: asyncpg.Connection, graph_name: str) -> None:
|
91 |
+
"""Set the Apache AGE environment and creates a graph if it does not exist.
|
92 |
+
|
93 |
+
This method:
|
94 |
+
- Sets the PostgreSQL `search_path` to include `ag_catalog`, ensuring that Apache AGE functions can be used without specifying the schema.
|
95 |
+
- Attempts to create a new graph with the provided `graph_name` if it does not already exist.
|
96 |
+
- Silently ignores errors related to the graph already existing.
|
97 |
+
|
98 |
+
"""
|
99 |
+
try:
|
100 |
+
await connection.execute( # type: ignore
|
101 |
+
'SET search_path = ag_catalog, "$user", public'
|
102 |
+
)
|
103 |
+
await connection.execute( # type: ignore
|
104 |
+
f"select create_graph('{graph_name}')"
|
105 |
+
)
|
106 |
+
except (
|
107 |
+
asyncpg.exceptions.InvalidSchemaNameError,
|
108 |
+
asyncpg.exceptions.UniqueViolationError,
|
109 |
+
):
|
110 |
+
pass
|
111 |
|
112 |
async def check_tables(self):
|
113 |
for k, v in TABLES.items():
|
114 |
try:
|
115 |
await self.query(f"SELECT 1 FROM {k} LIMIT 1")
|
116 |
+
except Exception:
|
|
|
117 |
try:
|
118 |
logger.info(f"PostgreSQL, Try Creating table {k} in database")
|
119 |
await self.execute(v["ddl"])
|
120 |
+
logger.info(
|
121 |
+
f"PostgreSQL, Creation success table {k} in PostgreSQL database"
|
122 |
+
)
|
123 |
except Exception as e:
|
124 |
logger.error(
|
125 |
f"PostgreSQL, Failed to create table {k} in database, Please verify the connection with PostgreSQL database, Got: {e}"
|
|
|
131 |
sql: str,
|
132 |
params: dict[str, Any] | None = None,
|
133 |
multirows: bool = False,
|
134 |
+
with_age: bool = False,
|
135 |
+
graph_name: str | None = None,
|
136 |
) -> dict[str, Any] | None | list[dict[str, Any]]:
|
137 |
async with self.pool.acquire() as connection: # type: ignore
|
138 |
+
if with_age and graph_name:
|
139 |
+
await self.configure_age(connection, graph_name) # type: ignore
|
140 |
+
elif with_age and not graph_name:
|
141 |
+
raise ValueError("Graph name is required when with_age is True")
|
142 |
+
|
143 |
try:
|
144 |
if params:
|
145 |
rows = await connection.fetch(sql, *params.values())
|
|
|
160 |
data = None
|
161 |
return data
|
162 |
except Exception as e:
|
163 |
+
logger.error(f"PostgreSQL database, error:{e}")
|
|
|
|
|
164 |
raise
|
165 |
|
166 |
async def execute(
|
|
|
168 |
sql: str,
|
169 |
data: dict[str, Any] | None = None,
|
170 |
upsert: bool = False,
|
171 |
+
with_age: bool = False,
|
172 |
+
graph_name: str | None = None,
|
173 |
):
|
174 |
try:
|
175 |
async with self.pool.acquire() as connection: # type: ignore
|
|
|
674 |
async def initialize(self):
|
675 |
if self.db is None:
|
676 |
self.db = await ClientManager.get_client()
|
|
|
|
|
|
|
677 |
|
678 |
async def finalize(self):
|
679 |
if self.db is not None:
|
|
|
844 |
data = await self.db.query(
|
845 |
query,
|
846 |
multirows=True,
|
847 |
+
with_age=True,
|
848 |
+
graph_name=self.graph_name,
|
849 |
)
|
850 |
else:
|
851 |
data = await self.db.execute(
|
852 |
query,
|
853 |
upsert=upsert,
|
854 |
+
with_age=True,
|
855 |
+
graph_name=self.graph_name,
|
856 |
)
|
857 |
+
|
858 |
except Exception as e:
|
859 |
raise PGGraphQueryException(
|
860 |
{
|