ArnoChen committed on
Commit
5f2a32b
·
1 Parent(s): 6c5270a

refactor database connection management and improve storage lifecycle handling

Browse files
lightrag/api/lightrag_server.py CHANGED
@@ -15,11 +15,6 @@ import logging
15
  import argparse
16
  from typing import List, Any, Literal, Optional, Dict
17
  from pydantic import BaseModel, Field, field_validator
18
- from lightrag import LightRAG, QueryParam
19
- from lightrag.base import DocProcessingStatus, DocStatus
20
- from lightrag.types import GPTKeywordExtractionFormat
21
- from lightrag.api import __api_version__
22
- from lightrag.utils import EmbeddingFunc
23
  from pathlib import Path
24
  import shutil
25
  import aiofiles
@@ -36,39 +31,13 @@ import configparser
36
  import traceback
37
  from datetime import datetime
38
 
 
 
 
 
 
39
  from lightrag.utils import logger
40
- from .ollama_api import (
41
- OllamaAPI,
42
- )
43
- from .ollama_api import ollama_server_infos
44
-
45
-
46
- def get_db_type_from_storage_class(class_name: str) -> str | None:
47
- """Determine database type based on storage class name"""
48
- if class_name.startswith("PG"):
49
- return "postgres"
50
- elif class_name.startswith("Oracle"):
51
- return "oracle"
52
- elif class_name.startswith("TiDB"):
53
- return "tidb"
54
- return None
55
-
56
-
57
def import_db_module(db_type: str):
    """Import and return the client class for *db_type*.

    Recognizes "postgres", "oracle" and "tidb"; any other value returns
    None. The import happens inside the function, so a backend module is
    only loaded when that backend is actually requested.
    """
    if db_type == "postgres":
        from ..kg.postgres_impl import PostgreSQLDB as client_cls
    elif db_type == "oracle":
        from ..kg.oracle_impl import OracleDB as client_cls
    elif db_type == "tidb":
        from ..kg.tidb_impl import TiDB as client_cls
    else:
        return None
    return client_cls
72
 
73
 
74
  # Load environment variables
@@ -929,52 +898,12 @@ def create_app(args):
929
  @asynccontextmanager
930
  async def lifespan(app: FastAPI):
931
  """Lifespan context manager for startup and shutdown events"""
932
- # Initialize database connections
933
- db_instances = {}
934
  # Store background tasks
935
  app.state.background_tasks = set()
936
 
937
  try:
938
- # Check which database types are used
939
- db_types = set()
940
- for storage_name, storage_instance in storage_instances:
941
- db_type = get_db_type_from_storage_class(
942
- storage_instance.__class__.__name__
943
- )
944
- if db_type:
945
- db_types.add(db_type)
946
-
947
- # Import and initialize databases as needed
948
- for db_type in db_types:
949
- if db_type == "postgres":
950
- DB = import_db_module("postgres")
951
- db = DB(_get_postgres_config())
952
- await db.initdb()
953
- await db.check_tables()
954
- db_instances["postgres"] = db
955
- elif db_type == "oracle":
956
- DB = import_db_module("oracle")
957
- db = DB(_get_oracle_config())
958
- await db.check_tables()
959
- db_instances["oracle"] = db
960
- elif db_type == "tidb":
961
- DB = import_db_module("tidb")
962
- db = DB(_get_tidb_config())
963
- await db.check_tables()
964
- db_instances["tidb"] = db
965
-
966
- # Inject database instances into storage classes
967
- for storage_name, storage_instance in storage_instances:
968
- db_type = get_db_type_from_storage_class(
969
- storage_instance.__class__.__name__
970
- )
971
- if db_type:
972
- if db_type not in db_instances:
973
- error_msg = f"Database type '{db_type}' is required by {storage_name} but not initialized"
974
- logger.error(error_msg)
975
- raise RuntimeError(error_msg)
976
- storage_instance.db = db_instances[db_type]
977
- logger.info(f"Injected {db_type} db to {storage_name}")
978
 
979
  # Auto scan documents if enabled
980
  if args.auto_scan_at_startup:
@@ -1000,17 +929,7 @@ def create_app(args):
1000
 
1001
  finally:
1002
  # Clean up database connections
1003
- for db_type, db in db_instances.items():
1004
- if hasattr(db, "pool"):
1005
- await db.pool.close()
1006
- # Use more accurate database name display
1007
- db_names = {
1008
- "postgres": "PostgreSQL",
1009
- "oracle": "Oracle",
1010
- "tidb": "TiDB",
1011
- }
1012
- db_name = db_names.get(db_type, db_type)
1013
- logger.info(f"Closed {db_name} database connection pool")
1014
 
1015
  # Initialize FastAPI
1016
  app = FastAPI(
@@ -1042,92 +961,6 @@ def create_app(args):
1042
  allow_headers=["*"],
1043
  )
1044
 
1045
- # Database configuration functions
1046
def _get_postgres_config():
    """Collect the PostgreSQL connection settings.

    Each value is taken from its ``POSTGRES_*`` environment variable when
    set, otherwise from the ``[postgres]`` section of ``config`` (an object
    from the enclosing scope), otherwise from the fallback listed below.
    """
    # (setting name, environment variable, fallback)
    settings = (
        ("host", "POSTGRES_HOST", "localhost"),
        ("port", "POSTGRES_PORT", 5432),
        ("user", "POSTGRES_USER", None),
        ("password", "POSTGRES_PASSWORD", None),
        ("database", "POSTGRES_DATABASE", None),
        ("workspace", "POSTGRES_WORKSPACE", "default"),
    )
    return {
        name: os.environ.get(
            env_var, config.get("postgres", name, fallback=fallback)
        )
        for name, env_var, fallback in settings
    }
1071
-
1072
def _get_oracle_config():
    """Collect the Oracle connection settings.

    Each value is taken from its ``ORACLE_*`` environment variable when set,
    otherwise from the ``[oracle]`` section of ``config`` (an object from
    the enclosing scope), otherwise from the fallback listed below.
    """
    # (setting name, environment variable, fallback)
    settings = (
        ("user", "ORACLE_USER", None),
        ("password", "ORACLE_PASSWORD", None),
        ("dsn", "ORACLE_DSN", None),
        ("config_dir", "ORACLE_CONFIG_DIR", None),
        ("wallet_location", "ORACLE_WALLET_LOCATION", None),
        ("wallet_password", "ORACLE_WALLET_PASSWORD", None),
        ("workspace", "ORACLE_WORKSPACE", "default"),
    )
    return {
        name: os.environ.get(
            env_var, config.get("oracle", name, fallback=fallback)
        )
        for name, env_var, fallback in settings
    }
1103
-
1104
def _get_tidb_config():
    """Collect the TiDB connection settings.

    Each value is taken from its ``TIDB_*`` environment variable when set,
    otherwise from the ``[tidb]`` section of ``config`` (an object from the
    enclosing scope), otherwise from the fallback listed below.
    """
    # (setting name, environment variable, fallback)
    settings = (
        ("host", "TIDB_HOST", "localhost"),
        ("port", "TIDB_PORT", 4000),
        ("user", "TIDB_USER", None),
        ("password", "TIDB_PASSWORD", None),
        ("database", "TIDB_DATABASE", None),
        ("workspace", "TIDB_WORKSPACE", "default"),
    )
    return {
        name: os.environ.get(
            env_var, config.get("tidb", name, fallback=fallback)
        )
        for name, env_var, fallback in settings
    }
1130
-
1131
  # Create the optional API key dependency
1132
  optional_api_key = get_api_key_dependency(api_key)
1133
 
@@ -1262,6 +1095,7 @@ def create_app(args):
1262
  },
1263
  log_level=args.log_level,
1264
  namespace_prefix=args.namespace_prefix,
 
1265
  )
1266
  else:
1267
  rag = LightRAG(
@@ -1293,20 +1127,9 @@ def create_app(args):
1293
  },
1294
  log_level=args.log_level,
1295
  namespace_prefix=args.namespace_prefix,
 
1296
  )
1297
 
1298
- # Collect all storage instances
1299
- storage_instances = [
1300
- ("full_docs", rag.full_docs),
1301
- ("text_chunks", rag.text_chunks),
1302
- ("chunk_entity_relation_graph", rag.chunk_entity_relation_graph),
1303
- ("entities_vdb", rag.entities_vdb),
1304
- ("relationships_vdb", rag.relationships_vdb),
1305
- ("chunks_vdb", rag.chunks_vdb),
1306
- ("doc_status", rag.doc_status),
1307
- ("llm_response_cache", rag.llm_response_cache),
1308
- ]
1309
-
1310
  async def pipeline_enqueue_file(file_path: Path) -> bool:
1311
  """Add a file to the queue for processing
1312
 
 
15
  import argparse
16
  from typing import List, Any, Literal, Optional, Dict
17
  from pydantic import BaseModel, Field, field_validator
 
 
 
 
 
18
  from pathlib import Path
19
  import shutil
20
  import aiofiles
 
31
  import traceback
32
  from datetime import datetime
33
 
34
+ from lightrag import LightRAG, QueryParam
35
+ from lightrag.base import DocProcessingStatus, DocStatus
36
+ from lightrag.types import GPTKeywordExtractionFormat
37
+ from lightrag.api import __api_version__
38
+ from lightrag.utils import EmbeddingFunc
39
  from lightrag.utils import logger
40
+ from .ollama_api import OllamaAPI, ollama_server_infos
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
 
43
  # Load environment variables
 
898
  @asynccontextmanager
899
  async def lifespan(app: FastAPI):
900
  """Lifespan context manager for startup and shutdown events"""
 
 
901
  # Store background tasks
902
  app.state.background_tasks = set()
903
 
904
  try:
905
+ # Initialize database connections
906
+ await rag.initialize_storages()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
907
 
908
  # Auto scan documents if enabled
909
  if args.auto_scan_at_startup:
 
929
 
930
  finally:
931
  # Clean up database connections
932
+ await rag.finalize_storages()
 
 
 
 
 
 
 
 
 
 
933
 
934
  # Initialize FastAPI
935
  app = FastAPI(
 
961
  allow_headers=["*"],
962
  )
963
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
964
  # Create the optional API key dependency
965
  optional_api_key = get_api_key_dependency(api_key)
966
 
 
1095
  },
1096
  log_level=args.log_level,
1097
  namespace_prefix=args.namespace_prefix,
1098
+ is_managed_by_server=True,
1099
  )
1100
  else:
1101
  rag = LightRAG(
 
1127
  },
1128
  log_level=args.log_level,
1129
  namespace_prefix=args.namespace_prefix,
1130
+ is_managed_by_server=True,
1131
  )
1132
 
 
 
 
 
 
 
 
 
 
 
 
 
1133
  async def pipeline_enqueue_file(file_path: Path) -> bool:
1134
  """Add a file to the queue for processing
1135
 
lightrag/base.py CHANGED
@@ -87,6 +87,14 @@ class StorageNameSpace(ABC):
87
  namespace: str
88
  global_config: dict[str, Any]
89
 
 
 
 
 
 
 
 
 
90
  @abstractmethod
91
  async def index_done_callback(self) -> None:
92
  """Commit the storage operations after indexing"""
@@ -247,3 +255,12 @@ class DocStatusStorage(BaseKVStorage, ABC):
247
  self, status: DocStatus
248
  ) -> dict[str, DocProcessingStatus]:
249
  """Get all documents with a specific status"""
 
 
 
 
 
 
 
 
 
 
87
  namespace: str
88
  global_config: dict[str, Any]
89
 
90
async def initialize(self):
    """Initialize the storage.

    The base implementation is a no-op hook.
    """
93
+
94
async def finalize(self):
    """Finalize the storage.

    The base implementation is a no-op hook.
    """
97
+
98
  @abstractmethod
99
  async def index_done_callback(self) -> None:
100
  """Commit the storage operations after indexing"""
 
255
  self, status: DocStatus
256
  ) -> dict[str, DocProcessingStatus]:
257
  """Get all documents with a specific status"""
258
+
259
+
260
class StoragesStatus(str, Enum):
    """Lifecycle state of a set of storages.

    Members are also plain strings, so they compare equal to their values.
    """

    # Listed in lifecycle order.
    NOT_CREATED = "not_created"
    CREATED = "created"
    INITIALIZED = "initialized"
    FINALIZED = "finalized"
lightrag/kg/oracle_impl.py CHANGED
@@ -2,11 +2,11 @@ import array
2
  import asyncio
3
 
4
  # import html
5
- # import os
6
  from dataclasses import dataclass
7
  from typing import Any, Union, final
8
-
9
  import numpy as np
 
10
 
11
  from lightrag.types import KnowledgeGraph
12
 
@@ -173,6 +173,72 @@ class OracleDB:
173
  raise
174
 
175
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
  @final
177
  @dataclass
178
  class OracleKVStorage(BaseKVStorage):
@@ -184,6 +250,15 @@ class OracleKVStorage(BaseKVStorage):
184
  self._data = {}
185
  self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
186
 
 
 
 
 
 
 
 
 
 
187
  ################ QUERY METHODS ################
188
 
189
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -329,6 +404,15 @@ class OracleVectorDBStorage(BaseVectorStorage):
329
  )
330
  self.cosine_better_than_threshold = cosine_threshold
331
 
 
 
 
 
 
 
 
 
 
332
  #################### query method ###############
333
  async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
334
  embeddings = await self.embedding_func([query])
@@ -368,6 +452,15 @@ class OracleGraphStorage(BaseGraphStorage):
368
  def __post_init__(self):
369
  self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
370
 
 
 
 
 
 
 
 
 
 
371
  #################### insert method ################
372
 
373
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
 
2
  import asyncio
3
 
4
  # import html
5
+ import os
6
  from dataclasses import dataclass
7
  from typing import Any, Union, final
 
8
  import numpy as np
9
+ import configparser
10
 
11
  from lightrag.types import KnowledgeGraph
12
 
 
173
  raise
174
 
175
 
176
class ClientManager:
    """Holder for a single, reference-counted ``OracleDB`` client.

    ``get_client()`` creates the client on first use and counts every
    caller; ``release_client()`` closes the client's pool once the last
    reference has been handed back.
    """

    # Cached client and the number of callers currently holding it.
    _instances = {"db": None, "ref_count": 0}
    # Serializes creation and release of the cached client.
    _lock = asyncio.Lock()

    @staticmethod
    def get_config():
        """Build the Oracle connection settings.

        Each value comes from its ``ORACLE_*`` environment variable when
        set, otherwise from the ``[oracle]`` section of ``config.ini``
        (path relative to the working directory), otherwise from the
        fallback listed below.
        """
        parser = configparser.ConfigParser()
        parser.read("config.ini", "utf-8")

        # (setting name, environment variable, fallback)
        settings = (
            ("user", "ORACLE_USER", None),
            ("password", "ORACLE_PASSWORD", None),
            ("dsn", "ORACLE_DSN", None),
            ("config_dir", "ORACLE_CONFIG_DIR", None),
            ("wallet_location", "ORACLE_WALLET_LOCATION", None),
            ("wallet_password", "ORACLE_WALLET_PASSWORD", None),
            ("workspace", "ORACLE_WORKSPACE", "default"),
        )
        return {
            name: os.environ.get(
                env_var, parser.get("oracle", name, fallback=fallback)
            )
            for name, env_var, fallback in settings
        }

    @classmethod
    async def get_client(cls) -> OracleDB:
        """Return the cached client, creating it on first use.

        A new client is built from ``get_config()`` and ``check_tables()``
        is awaited before it is cached. Every call increments the
        reference count.
        """
        async with cls._lock:
            if cls._instances["db"] is None:
                db = OracleDB(ClientManager.get_config())
                await db.check_tables()
                cls._instances["db"] = db
                cls._instances["ref_count"] = 0
            cls._instances["ref_count"] += 1
            return cls._instances["db"]

    @classmethod
    async def release_client(cls, db: OracleDB):
        """Give back a client obtained from ``get_client()``.

        ``None`` is ignored. A client that is not the cached one has its
        pool closed immediately. For the cached client the reference count
        is decremented, and the pool is closed when it reaches zero.
        """
        async with cls._lock:
            if db is None:
                return
            if db is not cls._instances["db"]:
                # Not tracked by this manager: close its pool right away.
                await db.pool.close()
                return

            cls._instances["ref_count"] -= 1
            if cls._instances["ref_count"] > 0:
                return

            try:
                await db.pool.close()
                logger.info("Closed OracleDB database connection pool")
            finally:
                # Drop the cached client even if closing failed, so a later
                # get_client() builds a fresh one instead of reusing a
                # client whose pool was being torn down.
                cls._instances["db"] = None
240
+
241
+
242
  @final
243
  @dataclass
244
  class OracleKVStorage(BaseKVStorage):
 
250
  self._data = {}
251
  self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
252
 
253
async def initialize(self):
    """Attach an Oracle client to this KV storage.

    Does nothing when ``self.db`` is already set; otherwise the client is
    obtained from ``ClientManager.get_client()``.
    """
    if getattr(self, "db", None) is not None:
        return
    self.db = await ClientManager.get_client()
256
+
257
async def finalize(self):
    """Release this KV storage's Oracle client and clear ``self.db``.

    A storage without a client (attribute missing or ``None``) is left
    untouched.
    """
    client = getattr(self, "db", None)
    if client is None:
        return
    await ClientManager.release_client(client)
    self.db = None
261
+
262
  ################ QUERY METHODS ################
263
 
264
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
 
404
  )
405
  self.cosine_better_than_threshold = cosine_threshold
406
 
407
async def initialize(self):
    """Attach an Oracle client to this vector storage.

    Does nothing when ``self.db`` is already set; otherwise the client is
    obtained from ``ClientManager.get_client()``.
    """
    if getattr(self, "db", None) is not None:
        return
    self.db = await ClientManager.get_client()
410
+
411
async def finalize(self):
    """Release this vector storage's Oracle client and clear ``self.db``.

    A storage without a client (attribute missing or ``None``) is left
    untouched.
    """
    client = getattr(self, "db", None)
    if client is None:
        return
    await ClientManager.release_client(client)
    self.db = None
415
+
416
  #################### query method ###############
417
  async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
418
  embeddings = await self.embedding_func([query])
 
452
  def __post_init__(self):
453
  self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
454
 
455
async def initialize(self):
    """Attach an Oracle client to this graph storage.

    Does nothing when ``self.db`` is already set; otherwise the client is
    obtained from ``ClientManager.get_client()``.
    """
    if getattr(self, "db", None) is not None:
        return
    self.db = await ClientManager.get_client()
458
+
459
async def finalize(self):
    """Release this graph storage's Oracle client and clear ``self.db``.

    A storage without a client (attribute missing or ``None``) is left
    untouched.
    """
    client = getattr(self, "db", None)
    if client is None:
        return
    await ClientManager.release_client(client)
    self.db = None
463
+
464
  #################### insert method ################
465
 
466
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
lightrag/kg/postgres_impl.py CHANGED
@@ -5,8 +5,8 @@ import os
5
  import time
6
  from dataclasses import dataclass
7
  from typing import Any, Dict, List, Union, final
8
-
9
  import numpy as np
 
10
 
11
  from lightrag.types import KnowledgeGraph
12
 
@@ -182,6 +182,67 @@ class PostgreSQLDB:
182
  pass
183
 
184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  @final
186
  @dataclass
187
  class PGKVStorage(BaseKVStorage):
@@ -191,6 +252,15 @@ class PGKVStorage(BaseKVStorage):
191
  def __post_init__(self):
192
  self._max_batch_size = self.global_config["embedding_batch_num"]
193
 
 
 
 
 
 
 
 
 
 
194
  ################ QUERY METHODS ################
195
 
196
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -319,6 +389,15 @@ class PGVectorStorage(BaseVectorStorage):
319
  )
320
  self.cosine_better_than_threshold = cosine_threshold
321
 
 
 
 
 
 
 
 
 
 
322
  def _upsert_chunks(self, item: dict):
323
  try:
324
  upsert_sql = SQL_TEMPLATES["upsert_chunk"]
@@ -435,6 +514,15 @@ class PGVectorStorage(BaseVectorStorage):
435
  @final
436
  @dataclass
437
  class PGDocStatusStorage(DocStatusStorage):
 
 
 
 
 
 
 
 
 
438
  async def filter_keys(self, keys: set[str]) -> set[str]:
439
  """Filter out duplicated content"""
440
  sql = SQL_TEMPLATES["filter_keys"].format(
@@ -584,6 +672,15 @@ class PGGraphStorage(BaseGraphStorage):
584
  "node2vec": self._node2vec_embed,
585
  }
586
 
 
 
 
 
 
 
 
 
 
587
  async def index_done_callback(self) -> None:
588
  # PG handles persistence automatically
589
  pass
 
5
  import time
6
  from dataclasses import dataclass
7
  from typing import Any, Dict, List, Union, final
 
8
  import numpy as np
9
+ import configparser
10
 
11
  from lightrag.types import KnowledgeGraph
12
 
 
182
  pass
183
 
184
 
185
class ClientManager:
    """Holder for a single, reference-counted ``PostgreSQLDB`` client.

    ``get_client()`` creates the client on first use and counts every
    caller; ``release_client()`` closes the client's pool once the last
    reference has been handed back.
    """

    # Cached client and the number of callers currently holding it.
    _instances = {"db": None, "ref_count": 0}
    # Serializes creation and release of the cached client.
    _lock = asyncio.Lock()

    @staticmethod
    def get_config():
        """Build the PostgreSQL connection settings.

        Each value comes from its ``POSTGRES_*`` environment variable when
        set, otherwise from the ``[postgres]`` section of ``config.ini``
        (path relative to the working directory), otherwise from the
        fallback listed below.
        """
        parser = configparser.ConfigParser()
        parser.read("config.ini", "utf-8")

        # (setting name, environment variable, fallback)
        settings = (
            ("host", "POSTGRES_HOST", "localhost"),
            ("port", "POSTGRES_PORT", 5432),
            ("user", "POSTGRES_USER", None),
            ("password", "POSTGRES_PASSWORD", None),
            ("database", "POSTGRES_DATABASE", None),
            ("workspace", "POSTGRES_WORKSPACE", "default"),
        )
        return {
            name: os.environ.get(
                env_var, parser.get("postgres", name, fallback=fallback)
            )
            for name, env_var, fallback in settings
        }

    @classmethod
    async def get_client(cls) -> PostgreSQLDB:
        """Return the cached client, creating it on first use.

        A new client is built from ``get_config()``; ``initdb()`` and then
        ``check_tables()`` are awaited before it is cached. Every call
        increments the reference count.
        """
        async with cls._lock:
            if cls._instances["db"] is None:
                db = PostgreSQLDB(ClientManager.get_config())
                await db.initdb()
                await db.check_tables()
                cls._instances["db"] = db
                cls._instances["ref_count"] = 0
            cls._instances["ref_count"] += 1
            return cls._instances["db"]

    @classmethod
    async def release_client(cls, db: PostgreSQLDB):
        """Give back a client obtained from ``get_client()``.

        ``None`` is ignored. A client that is not the cached one has its
        pool closed immediately. For the cached client the reference count
        is decremented, and the pool is closed when it reaches zero.
        """
        async with cls._lock:
            if db is None:
                return
            if db is not cls._instances["db"]:
                # Not tracked by this manager: close its pool right away.
                await db.pool.close()
                return

            cls._instances["ref_count"] -= 1
            if cls._instances["ref_count"] > 0:
                return

            try:
                await db.pool.close()
                logger.info("Closed PostgreSQL database connection pool")
            finally:
                # Drop the cached client even if closing failed, so a later
                # get_client() builds a fresh one instead of reusing a
                # client whose pool was being torn down.
                cls._instances["db"] = None
244
+
245
+
246
  @final
247
  @dataclass
248
  class PGKVStorage(BaseKVStorage):
 
252
  def __post_init__(self):
253
  self._max_batch_size = self.global_config["embedding_batch_num"]
254
 
255
async def initialize(self):
    """Attach a PostgreSQL client to this KV storage.

    Does nothing when ``self.db`` is already set; otherwise the client is
    obtained from ``ClientManager.get_client()``.
    """
    if getattr(self, "db", None) is not None:
        return
    self.db = await ClientManager.get_client()
258
+
259
async def finalize(self):
    """Release this KV storage's PostgreSQL client and clear ``self.db``.

    A storage without a client (attribute missing or ``None``) is left
    untouched.
    """
    client = getattr(self, "db", None)
    if client is None:
        return
    await ClientManager.release_client(client)
    self.db = None
263
+
264
  ################ QUERY METHODS ################
265
 
266
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
 
389
  )
390
  self.cosine_better_than_threshold = cosine_threshold
391
 
392
async def initialize(self):
    """Attach a PostgreSQL client to this vector storage.

    Does nothing when ``self.db`` is already set; otherwise the client is
    obtained from ``ClientManager.get_client()``.
    """
    if getattr(self, "db", None) is not None:
        return
    self.db = await ClientManager.get_client()
395
+
396
async def finalize(self):
    """Release this vector storage's PostgreSQL client and clear ``self.db``.

    A storage without a client (attribute missing or ``None``) is left
    untouched.
    """
    client = getattr(self, "db", None)
    if client is None:
        return
    await ClientManager.release_client(client)
    self.db = None
400
+
401
  def _upsert_chunks(self, item: dict):
402
  try:
403
  upsert_sql = SQL_TEMPLATES["upsert_chunk"]
 
514
  @final
515
  @dataclass
516
  class PGDocStatusStorage(DocStatusStorage):
517
async def initialize(self):
    """Attach a PostgreSQL client to this doc-status storage.

    Does nothing when ``self.db`` is already set; otherwise the client is
    obtained from ``ClientManager.get_client()``.
    """
    if getattr(self, "db", None) is not None:
        return
    self.db = await ClientManager.get_client()
520
+
521
async def finalize(self):
    """Release this doc-status storage's PostgreSQL client and clear ``self.db``.

    A storage without a client (attribute missing or ``None``) is left
    untouched.
    """
    client = getattr(self, "db", None)
    if client is None:
        return
    await ClientManager.release_client(client)
    self.db = None
525
+
526
  async def filter_keys(self, keys: set[str]) -> set[str]:
527
  """Filter out duplicated content"""
528
  sql = SQL_TEMPLATES["filter_keys"].format(
 
672
  "node2vec": self._node2vec_embed,
673
  }
674
 
675
async def initialize(self):
    """Attach a PostgreSQL client to this graph storage.

    Does nothing when ``self.db`` is already set; otherwise the client is
    obtained from ``ClientManager.get_client()``.
    """
    if getattr(self, "db", None) is not None:
        return
    self.db = await ClientManager.get_client()
678
+
679
async def finalize(self):
    """Release this graph storage's PostgreSQL client and clear ``self.db``.

    A storage without a client (attribute missing or ``None``) is left
    untouched.
    """
    client = getattr(self, "db", None)
    if client is None:
        return
    await ClientManager.release_client(client)
    self.db = None
683
+
684
  async def index_done_callback(self) -> None:
685
  # PG handles persistence automatically
686
  pass
lightrag/kg/tidb_impl.py CHANGED
@@ -14,6 +14,7 @@ from ..namespace import NameSpace, is_namespace
14
  from ..utils import logger
15
 
16
  import pipmaster as pm
 
17
 
18
  if not pm.is_installed("pymysql"):
19
  pm.install("pymysql")
@@ -105,6 +106,63 @@ class TiDB:
105
  raise
106
 
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  @final
109
  @dataclass
110
  class TiDBKVStorage(BaseKVStorage):
@@ -115,6 +173,15 @@ class TiDBKVStorage(BaseKVStorage):
115
  self._data = {}
116
  self._max_batch_size = self.global_config["embedding_batch_num"]
117
 
 
 
 
 
 
 
 
 
 
118
  ################ QUERY METHODS ################
119
 
120
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -185,7 +252,7 @@ class TiDBKVStorage(BaseKVStorage):
185
  "tokens": item["tokens"],
186
  "chunk_order_index": item["chunk_order_index"],
187
  "full_doc_id": item["full_doc_id"],
188
- "content_vector": f'{item["__vector__"].tolist()}',
189
  "workspace": self.db.workspace,
190
  }
191
  )
@@ -226,6 +293,15 @@ class TiDBVectorDBStorage(BaseVectorStorage):
226
  )
227
  self.cosine_better_than_threshold = cosine_threshold
228
 
 
 
 
 
 
 
 
 
 
229
  async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
230
  """Search from tidb vector"""
231
  embeddings = await self.embedding_func([query])
@@ -290,7 +366,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
290
  "id": item["id"],
291
  "name": item["entity_name"],
292
  "content": item["content"],
293
- "content_vector": f'{item["content_vector"].tolist()}',
294
  "workspace": self.db.workspace,
295
  }
296
  # update entity_id if node inserted by graph_storage_instance before
@@ -312,7 +388,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
312
  "source_name": item["src_id"],
313
  "target_name": item["tgt_id"],
314
  "content": item["content"],
315
- "content_vector": f'{item["content_vector"].tolist()}',
316
  "workspace": self.db.workspace,
317
  }
318
  # update relation_id if node inserted by graph_storage_instance before
@@ -351,6 +427,15 @@ class TiDBGraphStorage(BaseGraphStorage):
351
  def __post_init__(self):
352
  self._max_batch_size = self.global_config["embedding_batch_num"]
353
 
 
 
 
 
 
 
 
 
 
354
  #################### upsert method ################
355
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
356
  entity_name = node_id
 
14
  from ..utils import logger
15
 
16
  import pipmaster as pm
17
+ import configparser
18
 
19
  if not pm.is_installed("pymysql"):
20
  pm.install("pymysql")
 
106
  raise
107
 
108
 
109
class ClientManager:
    """Holder for a single, reference-counted ``TiDB`` client.

    ``get_client()`` creates the client on first use and counts every
    caller; ``release_client()`` forgets the cached client once the last
    reference has been handed back.
    """

    # Cached client and the number of callers currently holding it.
    _instances = {"db": None, "ref_count": 0}
    # Serializes creation and release of the cached client.
    _lock = asyncio.Lock()

    @staticmethod
    def get_config():
        """Build the TiDB connection settings.

        Each value comes from its ``TIDB_*`` environment variable when set,
        otherwise from the ``[tidb]`` section of ``config.ini`` (path
        relative to the working directory), otherwise from the fallback
        listed below.
        """
        parser = configparser.ConfigParser()
        parser.read("config.ini", "utf-8")

        # (setting name, environment variable, fallback)
        settings = (
            ("host", "TIDB_HOST", "localhost"),
            ("port", "TIDB_PORT", 4000),
            ("user", "TIDB_USER", None),
            ("password", "TIDB_PASSWORD", None),
            ("database", "TIDB_DATABASE", None),
            ("workspace", "TIDB_WORKSPACE", "default"),
        )
        return {
            name: os.environ.get(
                env_var, parser.get("tidb", name, fallback=fallback)
            )
            for name, env_var, fallback in settings
        }

    @classmethod
    async def get_client(cls) -> TiDB:
        """Return the cached client, creating it on first use.

        A new client is built from ``get_config()`` and ``check_tables()``
        is awaited before it is cached. Every call increments the
        reference count.
        """
        async with cls._lock:
            if cls._instances["db"] is None:
                db = TiDB(ClientManager.get_config())
                await db.check_tables()
                cls._instances["db"] = db
                cls._instances["ref_count"] = 0
            cls._instances["ref_count"] += 1
            return cls._instances["db"]

    @classmethod
    async def release_client(cls, db: TiDB):
        """Give back a client obtained from ``get_client()``.

        ``None`` and clients other than the cached one are ignored. For the
        cached client the reference count is decremented and the cache is
        cleared when it reaches zero.
        """
        async with cls._lock:
            if db is None or db is not cls._instances["db"]:
                return
            cls._instances["ref_count"] -= 1
            if cls._instances["ref_count"] == 0:
                # NOTE(review): only the cached reference is dropped; no
                # close/dispose call is made on the client. Confirm that
                # is intended.
                cls._instances["db"] = None
164
+
165
+
166
  @final
167
  @dataclass
168
  class TiDBKVStorage(BaseKVStorage):
 
173
  self._data = {}
174
  self._max_batch_size = self.global_config["embedding_batch_num"]
175
 
176
async def initialize(self):
    """Attach a TiDB client to this KV storage.

    Does nothing when ``self.db`` is already set; otherwise the client is
    obtained from ``ClientManager.get_client()``.
    """
    if getattr(self, "db", None) is not None:
        return
    self.db = await ClientManager.get_client()
179
+
180
async def finalize(self):
    """Release this KV storage's TiDB client and clear ``self.db``.

    A storage without a client (attribute missing or ``None``) is left
    untouched.
    """
    client = getattr(self, "db", None)
    if client is None:
        return
    await ClientManager.release_client(client)
    self.db = None
184
+
185
  ################ QUERY METHODS ################
186
 
187
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
 
252
  "tokens": item["tokens"],
253
  "chunk_order_index": item["chunk_order_index"],
254
  "full_doc_id": item["full_doc_id"],
255
+ "content_vector": f"{item['__vector__'].tolist()}",
256
  "workspace": self.db.workspace,
257
  }
258
  )
 
293
  )
294
  self.cosine_better_than_threshold = cosine_threshold
295
 
296
async def initialize(self):
    """Attach a TiDB client to this vector storage.

    Does nothing when ``self.db`` is already set; otherwise the client is
    obtained from ``ClientManager.get_client()``.
    """
    if getattr(self, "db", None) is not None:
        return
    self.db = await ClientManager.get_client()
299
+
300
async def finalize(self):
    """Release this vector storage's TiDB client and clear ``self.db``.

    A storage without a client (attribute missing or ``None``) is left
    untouched.
    """
    client = getattr(self, "db", None)
    if client is None:
        return
    await ClientManager.release_client(client)
    self.db = None
304
+
305
  async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
306
  """Search from tidb vector"""
307
  embeddings = await self.embedding_func([query])
 
366
  "id": item["id"],
367
  "name": item["entity_name"],
368
  "content": item["content"],
369
+ "content_vector": f"{item['content_vector'].tolist()}",
370
  "workspace": self.db.workspace,
371
  }
372
  # update entity_id if node inserted by graph_storage_instance before
 
388
  "source_name": item["src_id"],
389
  "target_name": item["tgt_id"],
390
  "content": item["content"],
391
+ "content_vector": f"{item['content_vector'].tolist()}",
392
  "workspace": self.db.workspace,
393
  }
394
  # update relation_id if node inserted by graph_storage_instance before
 
427
  def __post_init__(self):
428
  self._max_batch_size = self.global_config["embedding_batch_num"]
429
 
430
async def initialize(self):
    """Attach a TiDB client to this graph storage.

    Does nothing when ``self.db`` is already set; otherwise the client is
    obtained from ``ClientManager.get_client()``.
    """
    if getattr(self, "db", None) is not None:
        return
    self.db = await ClientManager.get_client()
433
+
434
async def finalize(self):
    """Release this graph storage's TiDB client and clear ``self.db``.

    A storage without a client (attribute missing or ``None``) is left
    untouched.
    """
    client = getattr(self, "db", None)
    if client is None:
        return
    await ClientManager.release_client(client)
    self.db = None
438
+
439
  #################### upsert method ################
440
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
441
  entity_name = node_id
lightrag/lightrag.py CHANGED
@@ -17,6 +17,7 @@ from .base import (
17
  DocStatusStorage,
18
  QueryParam,
19
  StorageNameSpace,
 
20
  )
21
  from .namespace import NameSpace, make_namespace
22
  from .operate import (
@@ -348,6 +349,9 @@ class LightRAG:
348
  # Extensions
349
  addon_params: dict[str, Any] = field(default_factory=dict)
350
 
 
 
 
351
  """Dictionary for additional parameters and extensions."""
352
  convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
353
  convert_response_to_json
@@ -440,7 +444,10 @@ class LightRAG:
440
  **self.vector_db_storage_cls_kwargs,
441
  }
442
 
443
- # show config
 
 
 
444
  global_config = asdict(self)
445
  _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
446
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
@@ -547,6 +554,65 @@ class LightRAG:
547
  )
548
  )
549
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
550
  async def get_graph_labels(self):
551
  text = await self.chunk_entity_relation_graph.get_all_labels()
552
  return text
 
17
  DocStatusStorage,
18
  QueryParam,
19
  StorageNameSpace,
20
+ StoragesStatus,
21
  )
22
  from .namespace import NameSpace, make_namespace
23
  from .operate import (
 
349
  # Extensions
350
  addon_params: dict[str, Any] = field(default_factory=dict)
351
 
352
"""Dictionary for additional parameters and extensions."""

# Ownership
# When False, the instance initializes its storages at the end of
# construction and finalizes them in __del__. When True, the owner (for
# example the API server's lifespan handler) is expected to call
# initialize_storages() / finalize_storages() itself.
is_managed_by_server: bool = False
356
  convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
357
  convert_response_to_json
 
444
  **self.vector_db_storage_cls_kwargs,
445
  }
446
 
447
+ # Life cycle
448
+ self.storages_status = StoragesStatus.NOT_CREATED
449
+
450
+ # Show config
451
  global_config = asdict(self)
452
  _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
453
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
 
554
  )
555
  )
556
 
557
+ self.storages_status = StoragesStatus.CREATED
558
+
559
+ # Initialize storages
560
+ if not self.is_managed_by_server:
561
+ loop = always_get_an_event_loop()
562
+ loop.run_until_complete(self.initialize_storages())
563
+
564
def __del__(self):
    """Finalize the storages when a self-managed instance is collected.

    Instances created with ``is_managed_by_server=True`` are skipped; their
    owner is expected to call ``finalize_storages()`` itself.
    """
    if self.is_managed_by_server:
        return

    # __post_init__ may have raised before ``storages_status`` was set, and
    # finalize_storages() only acts on initialized storages anyway.
    if getattr(self, "storages_status", None) != StoragesStatus.INITIALIZED:
        return

    loop = always_get_an_event_loop()
    if loop.is_running():
        # run_until_complete() raises RuntimeError on a running loop and
        # would leave the coroutine un-awaited; report it instead.
        logger.warning(
            "Event loop is running; storages were not finalized in __del__. "
            "Call finalize_storages() explicitly."
        )
        return
    loop.run_until_complete(self.finalize_storages())
569
+
570
async def initialize_storages(self):
    """Asynchronously initialize the storages.

    Only acts while the status is ``StoragesStatus.CREATED``: every storage
    that is truthy has ``initialize()`` awaited concurrently, after which
    the status becomes ``StoragesStatus.INITIALIZED``.
    """
    if self.storages_status != StoragesStatus.CREATED:
        return

    storages = (
        self.full_docs,
        self.text_chunks,
        self.entities_vdb,
        self.relationships_vdb,
        self.chunks_vdb,
        self.chunk_entity_relation_graph,
        self.llm_response_cache,
        self.doc_status,
    )
    pending = [storage.initialize() for storage in storages if storage]
    await asyncio.gather(*pending)

    self.storages_status = StoragesStatus.INITIALIZED
    logger.debug("Initialized Storages")
592
+
593
async def finalize_storages(self):
    """Asynchronously finalize the storages.

    Only acts while the status is ``StoragesStatus.INITIALIZED``: every
    storage that is truthy has ``finalize()`` awaited concurrently, after
    which the status becomes ``StoragesStatus.FINALIZED``.
    """
    if self.storages_status != StoragesStatus.INITIALIZED:
        return

    storages = (
        self.full_docs,
        self.text_chunks,
        self.entities_vdb,
        self.relationships_vdb,
        self.chunks_vdb,
        self.chunk_entity_relation_graph,
        self.llm_response_cache,
        self.doc_status,
    )
    pending = [storage.finalize() for storage in storages if storage]
    await asyncio.gather(*pending)
    logger.debug("Finalized Storages")

    self.storages_status = StoragesStatus.FINALIZED
615
+
616
  async def get_graph_labels(self):
617
  text = await self.chunk_entity_relation_graph.get_all_labels()
618
  return text