feat: Add SSL support for PostgreSQL database connections
Browse files- Add SSL configuration options (ssl_mode, ssl_cert, ssl_key, ssl_root_cert, ssl_crl)
- Support all PostgreSQL SSL modes (disable, allow, prefer, require, verify-ca, verify-full)
- Add SSL context creation with certificate validation
- Update initdb() method to handle SSL connection parameters
- Add SSL environment variables to env.example
- Maintain backward compatibility with existing non-SSL configurations
- env.example +7 -0
- lightrag/kg/postgres_impl.py +126 -10
env.example
CHANGED
@@ -189,6 +189,13 @@ POSTGRES_DATABASE=your_database
|
|
189 |
POSTGRES_MAX_CONNECTIONS=12
|
190 |
# POSTGRES_WORKSPACE=forced_workspace_name
|
191 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
192 |
### Neo4j Configuration
|
193 |
NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io
|
194 |
NEO4J_USERNAME=neo4j
|
|
|
189 |
POSTGRES_MAX_CONNECTIONS=12
|
190 |
# POSTGRES_WORKSPACE=forced_workspace_name
|
191 |
|
192 |
+
### PostgreSQL SSL Configuration (Optional)
|
193 |
+
# POSTGRES_SSL_MODE=require
|
194 |
+
# POSTGRES_SSL_CERT=/path/to/client-cert.pem
|
195 |
+
# POSTGRES_SSL_KEY=/path/to/client-key.pem
|
196 |
+
# POSTGRES_SSL_ROOT_CERT=/path/to/ca-cert.pem
|
197 |
+
# POSTGRES_SSL_CRL=/path/to/crl.pem
|
198 |
+
|
199 |
### Neo4j Configuration
|
200 |
NEO4J_URI=neo4j+s://xxxxxxxx.databases.neo4j.io
|
201 |
NEO4J_USERNAME=neo4j
|
lightrag/kg/postgres_impl.py
CHANGED
@@ -8,6 +8,7 @@ from dataclasses import dataclass, field
|
|
8 |
from typing import Any, Union, final
|
9 |
import numpy as np
|
10 |
import configparser
|
|
|
11 |
|
12 |
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
13 |
|
@@ -58,27 +59,121 @@ class PostgreSQLDB:
|
|
58 |
self.increment = 1
|
59 |
self.pool: Pool | None = None
|
60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
if self.user is None or self.password is None or self.database is None:
|
62 |
raise ValueError("Missing database user, password, or database")
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
async def initdb(self):
|
65 |
try:
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
|
76 |
# Ensure VECTOR extension is available
|
77 |
async with self.pool.acquire() as connection:
|
78 |
await self.configure_vector_extension(connection)
|
79 |
|
|
|
80 |
logger.info(
|
81 |
-
f"PostgreSQL, Connected to database at {self.host}:{self.port}/{self.database}"
|
82 |
)
|
83 |
except Exception as e:
|
84 |
logger.error(
|
@@ -809,6 +904,27 @@ class ClientManager:
|
|
809 |
"POSTGRES_MAX_CONNECTIONS",
|
810 |
config.get("postgres", "max_connections", fallback=20),
|
811 |
),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
812 |
}
|
813 |
|
814 |
@classmethod
|
|
|
8 |
from typing import Any, Union, final
|
9 |
import numpy as np
|
10 |
import configparser
|
11 |
+
import ssl
|
12 |
|
13 |
from lightrag.types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
|
14 |
|
|
|
59 |
self.increment = 1
|
60 |
self.pool: Pool | None = None
|
61 |
|
62 |
+
# SSL configuration
|
63 |
+
self.ssl_mode = config.get("ssl_mode")
|
64 |
+
self.ssl_cert = config.get("ssl_cert")
|
65 |
+
self.ssl_key = config.get("ssl_key")
|
66 |
+
self.ssl_root_cert = config.get("ssl_root_cert")
|
67 |
+
self.ssl_crl = config.get("ssl_crl")
|
68 |
+
|
69 |
if self.user is None or self.password is None or self.database is None:
|
70 |
raise ValueError("Missing database user, password, or database")
|
71 |
|
72 |
+
def _create_ssl_context(self) -> ssl.SSLContext | None:
|
73 |
+
"""Create SSL context based on configuration parameters."""
|
74 |
+
if not self.ssl_mode:
|
75 |
+
return None
|
76 |
+
|
77 |
+
ssl_mode = self.ssl_mode.lower()
|
78 |
+
|
79 |
+
# For simple modes that don't require custom context
|
80 |
+
if ssl_mode in ["disable", "allow", "prefer", "require"]:
|
81 |
+
if ssl_mode == "disable":
|
82 |
+
return None
|
83 |
+
elif ssl_mode in ["require", "prefer"]:
|
84 |
+
# Return None for simple SSL requirement, handled in initdb
|
85 |
+
return None
|
86 |
+
|
87 |
+
# For modes that require certificate verification
|
88 |
+
if ssl_mode in ["verify-ca", "verify-full"]:
|
89 |
+
try:
|
90 |
+
context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
|
91 |
+
|
92 |
+
# Configure certificate verification
|
93 |
+
if ssl_mode == "verify-ca":
|
94 |
+
context.check_hostname = False
|
95 |
+
elif ssl_mode == "verify-full":
|
96 |
+
context.check_hostname = True
|
97 |
+
|
98 |
+
# Load root certificate if provided
|
99 |
+
if self.ssl_root_cert:
|
100 |
+
if os.path.exists(self.ssl_root_cert):
|
101 |
+
context.load_verify_locations(cafile=self.ssl_root_cert)
|
102 |
+
logger.info(
|
103 |
+
f"PostgreSQL, Loaded SSL root certificate: {self.ssl_root_cert}"
|
104 |
+
)
|
105 |
+
else:
|
106 |
+
logger.warning(
|
107 |
+
f"PostgreSQL, SSL root certificate file not found: {self.ssl_root_cert}"
|
108 |
+
)
|
109 |
+
|
110 |
+
# Load client certificate and key if provided
|
111 |
+
if self.ssl_cert and self.ssl_key:
|
112 |
+
if os.path.exists(self.ssl_cert) and os.path.exists(self.ssl_key):
|
113 |
+
context.load_cert_chain(self.ssl_cert, self.ssl_key)
|
114 |
+
logger.info(
|
115 |
+
f"PostgreSQL, Loaded SSL client certificate: {self.ssl_cert}"
|
116 |
+
)
|
117 |
+
else:
|
118 |
+
logger.warning(
|
119 |
+
"PostgreSQL, SSL client certificate or key file not found"
|
120 |
+
)
|
121 |
+
|
122 |
+
# Load certificate revocation list if provided
|
123 |
+
if self.ssl_crl:
|
124 |
+
if os.path.exists(self.ssl_crl):
|
125 |
+
context.load_verify_locations(crlfile=self.ssl_crl)
|
126 |
+
logger.info(f"PostgreSQL, Loaded SSL CRL: {self.ssl_crl}")
|
127 |
+
else:
|
128 |
+
logger.warning(
|
129 |
+
f"PostgreSQL, SSL CRL file not found: {self.ssl_crl}"
|
130 |
+
)
|
131 |
+
|
132 |
+
return context
|
133 |
+
|
134 |
+
except Exception as e:
|
135 |
+
logger.error(f"PostgreSQL, Failed to create SSL context: {e}")
|
136 |
+
raise ValueError(f"SSL configuration error: {e}")
|
137 |
+
|
138 |
+
# Unknown SSL mode
|
139 |
+
logger.warning(f"PostgreSQL, Unknown SSL mode: {ssl_mode}, SSL disabled")
|
140 |
+
return None
|
141 |
+
|
142 |
async def initdb(self):
|
143 |
try:
|
144 |
+
# Prepare connection parameters
|
145 |
+
connection_params = {
|
146 |
+
"user": self.user,
|
147 |
+
"password": self.password,
|
148 |
+
"database": self.database,
|
149 |
+
"host": self.host,
|
150 |
+
"port": self.port,
|
151 |
+
"min_size": 1,
|
152 |
+
"max_size": self.max,
|
153 |
+
}
|
154 |
+
|
155 |
+
# Add SSL configuration if provided
|
156 |
+
ssl_context = self._create_ssl_context()
|
157 |
+
if ssl_context is not None:
|
158 |
+
connection_params["ssl"] = ssl_context
|
159 |
+
logger.info("PostgreSQL, SSL configuration applied")
|
160 |
+
elif self.ssl_mode:
|
161 |
+
# Handle simple SSL modes without custom context
|
162 |
+
if self.ssl_mode.lower() in ["require", "prefer"]:
|
163 |
+
connection_params["ssl"] = True
|
164 |
+
elif self.ssl_mode.lower() == "disable":
|
165 |
+
connection_params["ssl"] = False
|
166 |
+
logger.info(f"PostgreSQL, SSL mode set to: {self.ssl_mode}")
|
167 |
+
|
168 |
+
self.pool = await asyncpg.create_pool(**connection_params) # type: ignore
|
169 |
|
170 |
# Ensure VECTOR extension is available
|
171 |
async with self.pool.acquire() as connection:
|
172 |
await self.configure_vector_extension(connection)
|
173 |
|
174 |
+
ssl_status = "with SSL" if connection_params.get("ssl") else "without SSL"
|
175 |
logger.info(
|
176 |
+
f"PostgreSQL, Connected to database at {self.host}:{self.port}/{self.database} {ssl_status}"
|
177 |
)
|
178 |
except Exception as e:
|
179 |
logger.error(
|
|
|
904 |
"POSTGRES_MAX_CONNECTIONS",
|
905 |
config.get("postgres", "max_connections", fallback=20),
|
906 |
),
|
907 |
+
# SSL configuration
|
908 |
+
"ssl_mode": os.environ.get(
|
909 |
+
"POSTGRES_SSL_MODE",
|
910 |
+
config.get("postgres", "ssl_mode", fallback=None),
|
911 |
+
),
|
912 |
+
"ssl_cert": os.environ.get(
|
913 |
+
"POSTGRES_SSL_CERT",
|
914 |
+
config.get("postgres", "ssl_cert", fallback=None),
|
915 |
+
),
|
916 |
+
"ssl_key": os.environ.get(
|
917 |
+
"POSTGRES_SSL_KEY",
|
918 |
+
config.get("postgres", "ssl_key", fallback=None),
|
919 |
+
),
|
920 |
+
"ssl_root_cert": os.environ.get(
|
921 |
+
"POSTGRES_SSL_ROOT_CERT",
|
922 |
+
config.get("postgres", "ssl_root_cert", fallback=None),
|
923 |
+
),
|
924 |
+
"ssl_crl": os.environ.get(
|
925 |
+
"POSTGRES_SSL_CRL",
|
926 |
+
config.get("postgres", "ssl_crl", fallback=None),
|
927 |
+
),
|
928 |
}
|
929 |
|
930 |
@classmethod
|