YanSte committed on
Commit
2d5998f
·
unverified ·
2 Parent(s): 5895546 0812991

Merge pull request #846 from ArnoChenFx/db-connection-and-storage-lifecycle

Browse files

Refactor Database Connection Management and Improve Storage Lifecycle Handling

examples/lightrag_api_oracle_demo.py CHANGED
@@ -17,7 +17,6 @@ from lightrag.llm.openai import openai_complete_if_cache, openai_embed
17
  from lightrag.utils import EmbeddingFunc
18
  import numpy as np
19
 
20
- from lightrag.kg.oracle_impl import OracleDB
21
 
22
  print(os.getcwd())
23
  script_directory = Path(__file__).resolve().parent.parent
@@ -48,6 +47,14 @@ print(f"EMBEDDING_MAX_TOKEN_SIZE: {EMBEDDING_MAX_TOKEN_SIZE}")
48
  if not os.path.exists(WORKING_DIR):
49
  os.mkdir(WORKING_DIR)
50
 
 
 
 
 
 
 
 
 
51
 
52
  async def llm_model_func(
53
  prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
@@ -89,20 +96,6 @@ async def init():
89
  # We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
90
  # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
91
 
92
- oracle_db = OracleDB(
93
- config={
94
- "user": "",
95
- "password": "",
96
- "dsn": "",
97
- "config_dir": "path_to_config_dir",
98
- "wallet_location": "path_to_wallet_location",
99
- "wallet_password": "wallet_password",
100
- "workspace": "company",
101
- } # specify which docs you want to store and query
102
- )
103
-
104
- # Check if Oracle DB tables exist, if not, tables will be created
105
- await oracle_db.check_tables()
106
  # Initialize LightRAG
107
  # We use Oracle DB as the KV/vector/graph storage
108
  # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
@@ -121,11 +114,6 @@ async def init():
121
  vector_storage="OracleVectorDBStorage",
122
  )
123
 
124
- # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
125
- rag.graph_storage_cls.db = oracle_db
126
- rag.key_string_value_json_storage_cls.db = oracle_db
127
- rag.vector_db_storage_cls.db = oracle_db
128
-
129
  return rag
130
 
131
 
 
17
  from lightrag.utils import EmbeddingFunc
18
  import numpy as np
19
 
 
20
 
21
  print(os.getcwd())
22
  script_directory = Path(__file__).resolve().parent.parent
 
47
  if not os.path.exists(WORKING_DIR):
48
  os.mkdir(WORKING_DIR)
49
 
50
+ os.environ["ORACLE_USER"] = ""
51
+ os.environ["ORACLE_PASSWORD"] = ""
52
+ os.environ["ORACLE_DSN"] = ""
53
+ os.environ["ORACLE_CONFIG_DIR"] = "path_to_config_dir"
54
+ os.environ["ORACLE_WALLET_LOCATION"] = "path_to_wallet_location"
55
+ os.environ["ORACLE_WALLET_PASSWORD"] = "wallet_password"
56
+ os.environ["ORACLE_WORKSPACE"] = "company"
57
+
58
 
59
  async def llm_model_func(
60
  prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 
96
  # We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
97
  # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  # Initialize LightRAG
100
  # We use Oracle DB as the KV/vector/graph storage
101
  # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
 
114
  vector_storage="OracleVectorDBStorage",
115
  )
116
 
 
 
 
 
 
117
  return rag
118
 
119
 
examples/lightrag_oracle_demo.py CHANGED
@@ -6,7 +6,6 @@ from lightrag import LightRAG, QueryParam
6
  from lightrag.llm.openai import openai_complete_if_cache, openai_embed
7
  from lightrag.utils import EmbeddingFunc
8
  import numpy as np
9
- from lightrag.kg.oracle_impl import OracleDB
10
 
11
  print(os.getcwd())
12
  script_directory = Path(__file__).resolve().parent.parent
@@ -26,6 +25,14 @@ MAX_TOKENS = 4000
26
  if not os.path.exists(WORKING_DIR):
27
  os.mkdir(WORKING_DIR)
28
 
 
 
 
 
 
 
 
 
29
 
30
  async def llm_model_func(
31
  prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
@@ -63,26 +70,6 @@ async def main():
63
  embedding_dimension = await get_embedding_dim()
64
  print(f"Detected embedding dimension: {embedding_dimension}")
65
 
66
- # Create Oracle DB connection
67
- # The `config` parameter is the connection configuration of Oracle DB
68
- # More docs here https://python-oracledb.readthedocs.io/en/latest/user_guide/connection_handling.html
69
- # We storage data in unified tables, so we need to set a `workspace` parameter to specify which docs we want to store and query
70
- # Below is an example of how to connect to Oracle Autonomous Database on Oracle Cloud
71
- oracle_db = OracleDB(
72
- config={
73
- "user": "username",
74
- "password": "xxxxxxxxx",
75
- "dsn": "xxxxxxx_medium",
76
- "config_dir": "dir/path/to/oracle/config",
77
- "wallet_location": "dir/path/to/oracle/wallet",
78
- "wallet_password": "xxxxxxxxx",
79
- "workspace": "company", # specify which docs you want to store and query
80
- }
81
- )
82
-
83
- # Check if Oracle DB tables exist, if not, tables will be created
84
- await oracle_db.check_tables()
85
-
86
  # Initialize LightRAG
87
  # We use Oracle DB as the KV/vector/graph storage
88
  # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
@@ -112,26 +99,6 @@ async def main():
112
  },
113
  )
114
 
115
- # Setthe KV/vector/graph storage's `db` property, so all operation will use same connection pool
116
-
117
- for storage in [
118
- rag.vector_db_storage_cls,
119
- rag.graph_storage_cls,
120
- rag.doc_status,
121
- rag.full_docs,
122
- rag.text_chunks,
123
- rag.llm_response_cache,
124
- rag.key_string_value_json_storage_cls,
125
- rag.chunks_vdb,
126
- rag.relationships_vdb,
127
- rag.entities_vdb,
128
- rag.graph_storage_cls,
129
- rag.chunk_entity_relation_graph,
130
- rag.llm_response_cache,
131
- ]:
132
- # set client
133
- storage.db = oracle_db
134
-
135
  # Extract and Insert into LightRAG storage
136
  with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
137
  all_text = f.read()
 
6
  from lightrag.llm.openai import openai_complete_if_cache, openai_embed
7
  from lightrag.utils import EmbeddingFunc
8
  import numpy as np
 
9
 
10
  print(os.getcwd())
11
  script_directory = Path(__file__).resolve().parent.parent
 
25
  if not os.path.exists(WORKING_DIR):
26
  os.mkdir(WORKING_DIR)
27
 
28
+ os.environ["ORACLE_USER"] = "username"
29
+ os.environ["ORACLE_PASSWORD"] = "xxxxxxxxx"
30
+ os.environ["ORACLE_DSN"] = "xxxxxxx_medium"
31
+ os.environ["ORACLE_CONFIG_DIR"] = "path_to_config_dir"
32
+ os.environ["ORACLE_WALLET_LOCATION"] = "path_to_wallet_location"
33
+ os.environ["ORACLE_WALLET_PASSWORD"] = "wallet_password"
34
+ os.environ["ORACLE_WORKSPACE"] = "company"
35
+
36
 
37
  async def llm_model_func(
38
  prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
 
70
  embedding_dimension = await get_embedding_dim()
71
  print(f"Detected embedding dimension: {embedding_dimension}")
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  # Initialize LightRAG
74
  # We use Oracle DB as the KV/vector/graph storage
75
  # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
 
99
  },
100
  )
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  # Extract and Insert into LightRAG storage
103
  with open(WORKING_DIR + "/docs.txt", "r", encoding="utf-8") as f:
104
  all_text = f.read()
examples/lightrag_tidb_demo.py CHANGED
@@ -4,7 +4,6 @@ import os
4
  import numpy as np
5
 
6
  from lightrag import LightRAG, QueryParam
7
- from lightrag.kg.tidb_impl import TiDB
8
  from lightrag.llm import siliconcloud_embedding, openai_complete_if_cache
9
  from lightrag.utils import EmbeddingFunc
10
 
@@ -17,11 +16,11 @@ APIKEY = ""
17
  CHATMODEL = ""
18
  EMBEDMODEL = ""
19
 
20
- TIDB_HOST = ""
21
- TIDB_PORT = ""
22
- TIDB_USER = ""
23
- TIDB_PASSWORD = ""
24
- TIDB_DATABASE = "lightrag"
25
 
26
  if not os.path.exists(WORKING_DIR):
27
  os.mkdir(WORKING_DIR)
@@ -62,21 +61,6 @@ async def main():
62
  embedding_dimension = await get_embedding_dim()
63
  print(f"Detected embedding dimension: {embedding_dimension}")
64
 
65
- # Create TiDB DB connection
66
- tidb = TiDB(
67
- config={
68
- "host": TIDB_HOST,
69
- "port": TIDB_PORT,
70
- "user": TIDB_USER,
71
- "password": TIDB_PASSWORD,
72
- "database": TIDB_DATABASE,
73
- "workspace": "company", # specify which docs you want to store and query
74
- }
75
- )
76
-
77
- # Check if TiDB DB tables exist, if not, tables will be created
78
- await tidb.check_tables()
79
-
80
  # Initialize LightRAG
81
  # We use TiDB DB as the KV/vector
82
  # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
@@ -95,15 +79,6 @@ async def main():
95
  graph_storage="TiDBGraphStorage",
96
  )
97
 
98
- if rag.llm_response_cache:
99
- rag.llm_response_cache.db = tidb
100
- rag.full_docs.db = tidb
101
- rag.text_chunks.db = tidb
102
- rag.entities_vdb.db = tidb
103
- rag.relationships_vdb.db = tidb
104
- rag.chunks_vdb.db = tidb
105
- rag.chunk_entity_relation_graph.db = tidb
106
-
107
  # Extract and Insert into LightRAG storage
108
  with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
109
  await rag.ainsert(f.read())
 
4
  import numpy as np
5
 
6
  from lightrag import LightRAG, QueryParam
 
7
  from lightrag.llm import siliconcloud_embedding, openai_complete_if_cache
8
  from lightrag.utils import EmbeddingFunc
9
 
 
16
  CHATMODEL = ""
17
  EMBEDMODEL = ""
18
 
19
+ os.environ["TIDB_HOST"] = ""
20
+ os.environ["TIDB_PORT"] = ""
21
+ os.environ["TIDB_USER"] = ""
22
+ os.environ["TIDB_PASSWORD"] = ""
23
+ os.environ["TIDB_DATABASE"] = "lightrag"
24
 
25
  if not os.path.exists(WORKING_DIR):
26
  os.mkdir(WORKING_DIR)
 
61
  embedding_dimension = await get_embedding_dim()
62
  print(f"Detected embedding dimension: {embedding_dimension}")
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  # Initialize LightRAG
65
  # We use TiDB DB as the KV/vector
66
  # You can add `addon_params={"example_number": 1, "language": "Simplfied Chinese"}` to control the prompt
 
79
  graph_storage="TiDBGraphStorage",
80
  )
81
 
 
 
 
 
 
 
 
 
 
82
  # Extract and Insert into LightRAG storage
83
  with open("./dickens/demo.txt", "r", encoding="utf-8") as f:
84
  await rag.ainsert(f.read())
examples/lightrag_zhipu_postgres_demo.py CHANGED
@@ -5,7 +5,6 @@ import time
5
  from dotenv import load_dotenv
6
 
7
  from lightrag import LightRAG, QueryParam
8
- from lightrag.kg.postgres_impl import PostgreSQLDB
9
  from lightrag.llm.zhipu import zhipu_complete
10
  from lightrag.llm.ollama import ollama_embedding
11
  from lightrag.utils import EmbeddingFunc
@@ -22,22 +21,14 @@ if not os.path.exists(WORKING_DIR):
22
  # AGE
23
  os.environ["AGE_GRAPH_NAME"] = "dickens"
24
 
25
- postgres_db = PostgreSQLDB(
26
- config={
27
- "host": "localhost",
28
- "port": 15432,
29
- "user": "rag",
30
- "password": "rag",
31
- "database": "rag",
32
- }
33
- )
34
 
35
 
36
  async def main():
37
- await postgres_db.initdb()
38
- # Check if PostgreSQL DB tables exist, if not, tables will be created
39
- await postgres_db.check_tables()
40
-
41
  rag = LightRAG(
42
  working_dir=WORKING_DIR,
43
  llm_model_func=zhipu_complete,
@@ -57,17 +48,7 @@ async def main():
57
  graph_storage="PGGraphStorage",
58
  vector_storage="PGVectorStorage",
59
  )
60
- # Set the KV/vector/graph storage's `db` property, so all operation will use same connection pool
61
- rag.doc_status.db = postgres_db
62
- rag.full_docs.db = postgres_db
63
- rag.text_chunks.db = postgres_db
64
- rag.llm_response_cache.db = postgres_db
65
- rag.key_string_value_json_storage_cls.db = postgres_db
66
- rag.chunks_vdb.db = postgres_db
67
- rag.relationships_vdb.db = postgres_db
68
- rag.entities_vdb.db = postgres_db
69
- rag.graph_storage_cls.db = postgres_db
70
- rag.chunk_entity_relation_graph.db = postgres_db
71
  # add embedding_func for graph database, it's deleted in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c
72
  rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func
73
 
 
5
  from dotenv import load_dotenv
6
 
7
  from lightrag import LightRAG, QueryParam
 
8
  from lightrag.llm.zhipu import zhipu_complete
9
  from lightrag.llm.ollama import ollama_embedding
10
  from lightrag.utils import EmbeddingFunc
 
21
  # AGE
22
  os.environ["AGE_GRAPH_NAME"] = "dickens"
23
 
24
+ os.environ["POSTGRES_HOST"] = "localhost"
25
+ os.environ["POSTGRES_PORT"] = "15432"
26
+ os.environ["POSTGRES_USER"] = "rag"
27
+ os.environ["POSTGRES_PASSWORD"] = "rag"
28
+ os.environ["POSTGRES_DATABASE"] = "rag"
 
 
 
 
29
 
30
 
31
  async def main():
 
 
 
 
32
  rag = LightRAG(
33
  working_dir=WORKING_DIR,
34
  llm_model_func=zhipu_complete,
 
48
  graph_storage="PGGraphStorage",
49
  vector_storage="PGVectorStorage",
50
  )
51
+
 
 
 
 
 
 
 
 
 
 
52
  # add embedding_func for graph database, it's deleted in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c
53
  rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func
54
 
lightrag/api/lightrag_server.py CHANGED
@@ -15,11 +15,6 @@ import logging
15
  import argparse
16
  from typing import List, Any, Literal, Optional, Dict
17
  from pydantic import BaseModel, Field, field_validator
18
- from lightrag import LightRAG, QueryParam
19
- from lightrag.base import DocProcessingStatus, DocStatus
20
- from lightrag.types import GPTKeywordExtractionFormat
21
- from lightrag.api import __api_version__
22
- from lightrag.utils import EmbeddingFunc
23
  from pathlib import Path
24
  import shutil
25
  import aiofiles
@@ -36,39 +31,13 @@ import configparser
36
  import traceback
37
  from datetime import datetime
38
 
 
 
 
 
 
39
  from lightrag.utils import logger
40
- from .ollama_api import (
41
- OllamaAPI,
42
- )
43
- from .ollama_api import ollama_server_infos
44
-
45
-
46
- def get_db_type_from_storage_class(class_name: str) -> str | None:
47
- """Determine database type based on storage class name"""
48
- if class_name.startswith("PG"):
49
- return "postgres"
50
- elif class_name.startswith("Oracle"):
51
- return "oracle"
52
- elif class_name.startswith("TiDB"):
53
- return "tidb"
54
- return None
55
-
56
-
57
- def import_db_module(db_type: str):
58
- """Dynamically import database module"""
59
- if db_type == "postgres":
60
- from ..kg.postgres_impl import PostgreSQLDB
61
-
62
- return PostgreSQLDB
63
- elif db_type == "oracle":
64
- from ..kg.oracle_impl import OracleDB
65
-
66
- return OracleDB
67
- elif db_type == "tidb":
68
- from ..kg.tidb_impl import TiDB
69
-
70
- return TiDB
71
- return None
72
 
73
 
74
  # Load environment variables
@@ -929,52 +898,12 @@ def create_app(args):
929
  @asynccontextmanager
930
  async def lifespan(app: FastAPI):
931
  """Lifespan context manager for startup and shutdown events"""
932
- # Initialize database connections
933
- db_instances = {}
934
  # Store background tasks
935
  app.state.background_tasks = set()
936
 
937
  try:
938
- # Check which database types are used
939
- db_types = set()
940
- for storage_name, storage_instance in storage_instances:
941
- db_type = get_db_type_from_storage_class(
942
- storage_instance.__class__.__name__
943
- )
944
- if db_type:
945
- db_types.add(db_type)
946
-
947
- # Import and initialize databases as needed
948
- for db_type in db_types:
949
- if db_type == "postgres":
950
- DB = import_db_module("postgres")
951
- db = DB(_get_postgres_config())
952
- await db.initdb()
953
- await db.check_tables()
954
- db_instances["postgres"] = db
955
- elif db_type == "oracle":
956
- DB = import_db_module("oracle")
957
- db = DB(_get_oracle_config())
958
- await db.check_tables()
959
- db_instances["oracle"] = db
960
- elif db_type == "tidb":
961
- DB = import_db_module("tidb")
962
- db = DB(_get_tidb_config())
963
- await db.check_tables()
964
- db_instances["tidb"] = db
965
-
966
- # Inject database instances into storage classes
967
- for storage_name, storage_instance in storage_instances:
968
- db_type = get_db_type_from_storage_class(
969
- storage_instance.__class__.__name__
970
- )
971
- if db_type:
972
- if db_type not in db_instances:
973
- error_msg = f"Database type '{db_type}' is required by {storage_name} but not initialized"
974
- logger.error(error_msg)
975
- raise RuntimeError(error_msg)
976
- storage_instance.db = db_instances[db_type]
977
- logger.info(f"Injected {db_type} db to {storage_name}")
978
 
979
  # Auto scan documents if enabled
980
  if args.auto_scan_at_startup:
@@ -1000,17 +929,7 @@ def create_app(args):
1000
 
1001
  finally:
1002
  # Clean up database connections
1003
- for db_type, db in db_instances.items():
1004
- if hasattr(db, "pool"):
1005
- await db.pool.close()
1006
- # Use more accurate database name display
1007
- db_names = {
1008
- "postgres": "PostgreSQL",
1009
- "oracle": "Oracle",
1010
- "tidb": "TiDB",
1011
- }
1012
- db_name = db_names.get(db_type, db_type)
1013
- logger.info(f"Closed {db_name} database connection pool")
1014
 
1015
  # Initialize FastAPI
1016
  app = FastAPI(
@@ -1042,92 +961,6 @@ def create_app(args):
1042
  allow_headers=["*"],
1043
  )
1044
 
1045
- # Database configuration functions
1046
- def _get_postgres_config():
1047
- return {
1048
- "host": os.environ.get(
1049
- "POSTGRES_HOST",
1050
- config.get("postgres", "host", fallback="localhost"),
1051
- ),
1052
- "port": os.environ.get(
1053
- "POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
1054
- ),
1055
- "user": os.environ.get(
1056
- "POSTGRES_USER", config.get("postgres", "user", fallback=None)
1057
- ),
1058
- "password": os.environ.get(
1059
- "POSTGRES_PASSWORD",
1060
- config.get("postgres", "password", fallback=None),
1061
- ),
1062
- "database": os.environ.get(
1063
- "POSTGRES_DATABASE",
1064
- config.get("postgres", "database", fallback=None),
1065
- ),
1066
- "workspace": os.environ.get(
1067
- "POSTGRES_WORKSPACE",
1068
- config.get("postgres", "workspace", fallback="default"),
1069
- ),
1070
- }
1071
-
1072
- def _get_oracle_config():
1073
- return {
1074
- "user": os.environ.get(
1075
- "ORACLE_USER",
1076
- config.get("oracle", "user", fallback=None),
1077
- ),
1078
- "password": os.environ.get(
1079
- "ORACLE_PASSWORD",
1080
- config.get("oracle", "password", fallback=None),
1081
- ),
1082
- "dsn": os.environ.get(
1083
- "ORACLE_DSN",
1084
- config.get("oracle", "dsn", fallback=None),
1085
- ),
1086
- "config_dir": os.environ.get(
1087
- "ORACLE_CONFIG_DIR",
1088
- config.get("oracle", "config_dir", fallback=None),
1089
- ),
1090
- "wallet_location": os.environ.get(
1091
- "ORACLE_WALLET_LOCATION",
1092
- config.get("oracle", "wallet_location", fallback=None),
1093
- ),
1094
- "wallet_password": os.environ.get(
1095
- "ORACLE_WALLET_PASSWORD",
1096
- config.get("oracle", "wallet_password", fallback=None),
1097
- ),
1098
- "workspace": os.environ.get(
1099
- "ORACLE_WORKSPACE",
1100
- config.get("oracle", "workspace", fallback="default"),
1101
- ),
1102
- }
1103
-
1104
- def _get_tidb_config():
1105
- return {
1106
- "host": os.environ.get(
1107
- "TIDB_HOST",
1108
- config.get("tidb", "host", fallback="localhost"),
1109
- ),
1110
- "port": os.environ.get(
1111
- "TIDB_PORT", config.get("tidb", "port", fallback=4000)
1112
- ),
1113
- "user": os.environ.get(
1114
- "TIDB_USER",
1115
- config.get("tidb", "user", fallback=None),
1116
- ),
1117
- "password": os.environ.get(
1118
- "TIDB_PASSWORD",
1119
- config.get("tidb", "password", fallback=None),
1120
- ),
1121
- "database": os.environ.get(
1122
- "TIDB_DATABASE",
1123
- config.get("tidb", "database", fallback=None),
1124
- ),
1125
- "workspace": os.environ.get(
1126
- "TIDB_WORKSPACE",
1127
- config.get("tidb", "workspace", fallback="default"),
1128
- ),
1129
- }
1130
-
1131
  # Create the optional API key dependency
1132
  optional_api_key = get_api_key_dependency(api_key)
1133
 
@@ -1262,6 +1095,7 @@ def create_app(args):
1262
  },
1263
  log_level=args.log_level,
1264
  namespace_prefix=args.namespace_prefix,
 
1265
  )
1266
  else:
1267
  rag = LightRAG(
@@ -1293,20 +1127,9 @@ def create_app(args):
1293
  },
1294
  log_level=args.log_level,
1295
  namespace_prefix=args.namespace_prefix,
 
1296
  )
1297
 
1298
- # Collect all storage instances
1299
- storage_instances = [
1300
- ("full_docs", rag.full_docs),
1301
- ("text_chunks", rag.text_chunks),
1302
- ("chunk_entity_relation_graph", rag.chunk_entity_relation_graph),
1303
- ("entities_vdb", rag.entities_vdb),
1304
- ("relationships_vdb", rag.relationships_vdb),
1305
- ("chunks_vdb", rag.chunks_vdb),
1306
- ("doc_status", rag.doc_status),
1307
- ("llm_response_cache", rag.llm_response_cache),
1308
- ]
1309
-
1310
  async def pipeline_enqueue_file(file_path: Path) -> bool:
1311
  """Add a file to the queue for processing
1312
 
 
15
  import argparse
16
  from typing import List, Any, Literal, Optional, Dict
17
  from pydantic import BaseModel, Field, field_validator
 
 
 
 
 
18
  from pathlib import Path
19
  import shutil
20
  import aiofiles
 
31
  import traceback
32
  from datetime import datetime
33
 
34
+ from lightrag import LightRAG, QueryParam
35
+ from lightrag.base import DocProcessingStatus, DocStatus
36
+ from lightrag.types import GPTKeywordExtractionFormat
37
+ from lightrag.api import __api_version__
38
+ from lightrag.utils import EmbeddingFunc
39
  from lightrag.utils import logger
40
+ from .ollama_api import OllamaAPI, ollama_server_infos
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
 
43
  # Load environment variables
 
898
  @asynccontextmanager
899
  async def lifespan(app: FastAPI):
900
  """Lifespan context manager for startup and shutdown events"""
 
 
901
  # Store background tasks
902
  app.state.background_tasks = set()
903
 
904
  try:
905
+ # Initialize database connections
906
+ await rag.initialize_storages()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
907
 
908
  # Auto scan documents if enabled
909
  if args.auto_scan_at_startup:
 
929
 
930
  finally:
931
  # Clean up database connections
932
+ await rag.finalize_storages()
 
 
 
 
 
 
 
 
 
 
933
 
934
  # Initialize FastAPI
935
  app = FastAPI(
 
961
  allow_headers=["*"],
962
  )
963
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
964
  # Create the optional API key dependency
965
  optional_api_key = get_api_key_dependency(api_key)
966
 
 
1095
  },
1096
  log_level=args.log_level,
1097
  namespace_prefix=args.namespace_prefix,
1098
+ auto_manage_storages_states=False,
1099
  )
1100
  else:
1101
  rag = LightRAG(
 
1127
  },
1128
  log_level=args.log_level,
1129
  namespace_prefix=args.namespace_prefix,
1130
+ auto_manage_storages_states=False,
1131
  )
1132
 
 
 
 
 
 
 
 
 
 
 
 
 
1133
  async def pipeline_enqueue_file(file_path: Path) -> bool:
1134
  """Add a file to the queue for processing
1135
 
lightrag/base.py CHANGED
@@ -87,6 +87,14 @@ class StorageNameSpace(ABC):
87
  namespace: str
88
  global_config: dict[str, Any]
89
 
 
 
 
 
 
 
 
 
90
  @abstractmethod
91
  async def index_done_callback(self) -> None:
92
  """Commit the storage operations after indexing"""
@@ -247,3 +255,12 @@ class DocStatusStorage(BaseKVStorage, ABC):
247
  self, status: DocStatus
248
  ) -> dict[str, DocProcessingStatus]:
249
  """Get all documents with a specific status"""
 
 
 
 
 
 
 
 
 
 
87
  namespace: str
88
  global_config: dict[str, Any]
89
 
90
+ async def initialize(self):
91
+ """Initialize the storage"""
92
+ pass
93
+
94
+ async def finalize(self):
95
+ """Finalize the storage"""
96
+ pass
97
+
98
  @abstractmethod
99
  async def index_done_callback(self) -> None:
100
  """Commit the storage operations after indexing"""
 
255
  self, status: DocStatus
256
  ) -> dict[str, DocProcessingStatus]:
257
  """Get all documents with a specific status"""
258
+
259
+
260
+ class StoragesStatus(str, Enum):
261
+ """Storages status"""
262
+
263
+ NOT_CREATED = "not_created"
264
+ CREATED = "created"
265
+ INITIALIZED = "initialized"
266
+ FINALIZED = "finalized"
lightrag/kg/mongo_impl.py CHANGED
@@ -1,5 +1,5 @@
1
  import os
2
- from dataclasses import dataclass
3
  import numpy as np
4
  import configparser
5
  import asyncio
@@ -26,8 +26,11 @@ if not pm.is_installed("motor"):
26
  pm.install("motor")
27
 
28
  try:
29
- from motor.motor_asyncio import AsyncIOMotorClient
30
- from pymongo import MongoClient
 
 
 
31
  from pymongo.operations import SearchIndexModel
32
  from pymongo.errors import PyMongoError
33
  except ImportError as e:
@@ -39,31 +42,63 @@ config = configparser.ConfigParser()
39
  config.read("config.ini", "utf-8")
40
 
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  @final
43
  @dataclass
44
  class MongoKVStorage(BaseKVStorage):
45
- def __post_init__(self):
46
- uri = os.environ.get(
47
- "MONGO_URI",
48
- config.get(
49
- "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
50
- ),
51
- )
52
- client = AsyncIOMotorClient(uri)
53
- database = client.get_database(
54
- os.environ.get(
55
- "MONGO_DATABASE",
56
- config.get("mongodb", "database", fallback="LightRAG"),
57
- )
58
- )
59
 
 
60
  self._collection_name = self.namespace
61
 
62
- self._data = database.get_collection(self._collection_name)
63
- logger.debug(f"Use MongoDB as KV {self._collection_name}")
 
 
 
64
 
65
- # Ensure collection exists
66
- create_collection_if_not_exists(uri, database.name, self._collection_name)
 
 
 
67
 
68
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
69
  return await self._data.find_one({"_id": id})
@@ -120,28 +155,23 @@ class MongoKVStorage(BaseKVStorage):
120
  @final
121
  @dataclass
122
  class MongoDocStatusStorage(DocStatusStorage):
123
- def __post_init__(self):
124
- uri = os.environ.get(
125
- "MONGO_URI",
126
- config.get(
127
- "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
128
- ),
129
- )
130
- client = AsyncIOMotorClient(uri)
131
- database = client.get_database(
132
- os.environ.get(
133
- "MONGO_DATABASE",
134
- config.get("mongodb", "database", fallback="LightRAG"),
135
- )
136
- )
137
 
 
138
  self._collection_name = self.namespace
139
- self._data = database.get_collection(self._collection_name)
140
 
141
- logger.debug(f"Use MongoDB as doc status {self._collection_name}")
 
 
 
 
142
 
143
- # Ensure collection exists
144
- create_collection_if_not_exists(uri, database.name, self._collection_name)
 
 
 
145
 
146
  async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
147
  return await self._data.find_one({"_id": id})
@@ -202,36 +232,33 @@ class MongoDocStatusStorage(DocStatusStorage):
202
  @dataclass
203
  class MongoGraphStorage(BaseGraphStorage):
204
  """
205
- A concrete implementation using MongoDBs $graphLookup to demonstrate multi-hop queries.
206
  """
207
 
 
 
 
208
  def __init__(self, namespace, global_config, embedding_func):
209
  super().__init__(
210
  namespace=namespace,
211
  global_config=global_config,
212
  embedding_func=embedding_func,
213
  )
214
- uri = os.environ.get(
215
- "MONGO_URI",
216
- config.get(
217
- "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
218
- ),
219
- )
220
- client = AsyncIOMotorClient(uri)
221
- database = client.get_database(
222
- os.environ.get(
223
- "MONGO_DATABASE",
224
- config.get("mongodb", "database", fallback="LightRAG"),
225
- )
226
- )
227
-
228
  self._collection_name = self.namespace
229
- self.collection = database.get_collection(self._collection_name)
230
 
231
- logger.debug(f"Use MongoDB as KG {self._collection_name}")
 
 
 
 
 
 
232
 
233
- # Ensure collection exists
234
- create_collection_if_not_exists(uri, database.name, self._collection_name)
 
 
 
235
 
236
  #
237
  # -------------------------------------------------------------------------
@@ -770,6 +797,9 @@ class MongoGraphStorage(BaseGraphStorage):
770
  @final
771
  @dataclass
772
  class MongoVectorDBStorage(BaseVectorStorage):
 
 
 
773
  def __post_init__(self):
774
  kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
775
  cosine_threshold = kwargs.get("cosine_better_than_threshold")
@@ -778,41 +808,36 @@ class MongoVectorDBStorage(BaseVectorStorage):
778
  "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
779
  )
780
  self.cosine_better_than_threshold = cosine_threshold
781
-
782
- uri = os.environ.get(
783
- "MONGO_URI",
784
- config.get(
785
- "mongodb", "uri", fallback="mongodb://root:root@localhost:27017/"
786
- ),
787
- )
788
- client = AsyncIOMotorClient(uri)
789
- database = client.get_database(
790
- os.environ.get(
791
- "MONGO_DATABASE",
792
- config.get("mongodb", "database", fallback="LightRAG"),
793
- )
794
- )
795
-
796
  self._collection_name = self.namespace
797
- self._data = database.get_collection(self._collection_name)
798
  self._max_batch_size = self.global_config["embedding_batch_num"]
799
 
800
- logger.debug(f"Use MongoDB as VDB {self._collection_name}")
 
 
 
801
 
802
- # Ensure collection exists
803
- create_collection_if_not_exists(uri, database.name, self._collection_name)
804
 
805
- # Ensure vector index exists
806
- self.create_vector_index(uri, database.name, self._collection_name)
807
 
808
- def create_vector_index(self, uri: str, database_name: str, collection_name: str):
809
- """Creates an Atlas Vector Search index."""
810
- client = MongoClient(uri)
811
- collection = client.get_database(database_name).get_collection(
812
- self._collection_name
813
- )
814
 
 
 
815
  try:
 
 
 
 
 
 
 
 
816
  search_index_model = SearchIndexModel(
817
  definition={
818
  "fields": [
@@ -824,11 +849,11 @@ class MongoVectorDBStorage(BaseVectorStorage):
824
  }
825
  ]
826
  },
827
- name="vector_knn_index",
828
  type="vectorSearch",
829
  )
830
 
831
- collection.create_search_index(search_index_model)
832
  logger.info("Vector index created successfully.")
833
 
834
  except PyMongoError as _:
@@ -913,15 +938,13 @@ class MongoVectorDBStorage(BaseVectorStorage):
913
  raise NotImplementedError
914
 
915
 
916
- def create_collection_if_not_exists(uri: str, database_name: str, collection_name: str):
917
- """Check if the collection exists. if not, create it."""
918
- client = MongoClient(uri)
919
- database = client.get_database(database_name)
920
-
921
- collection_names = database.list_collection_names()
922
 
923
  if collection_name not in collection_names:
924
- database.create_collection(collection_name)
925
  logger.info(f"Created collection: {collection_name}")
 
926
  else:
927
  logger.debug(f"Collection '{collection_name}' already exists.")
 
 
1
  import os
2
+ from dataclasses import dataclass, field
3
  import numpy as np
4
  import configparser
5
  import asyncio
 
26
  pm.install("motor")
27
 
28
  try:
29
+ from motor.motor_asyncio import (
30
+ AsyncIOMotorClient,
31
+ AsyncIOMotorDatabase,
32
+ AsyncIOMotorCollection,
33
+ )
34
  from pymongo.operations import SearchIndexModel
35
  from pymongo.errors import PyMongoError
36
  except ImportError as e:
 
42
  config.read("config.ini", "utf-8")
43
 
44
 
45
+ class ClientManager:
46
+ _instances = {"db": None, "ref_count": 0}
47
+ _lock = asyncio.Lock()
48
+
49
+ @classmethod
50
+ async def get_client(cls) -> AsyncIOMotorDatabase:
51
+ async with cls._lock:
52
+ if cls._instances["db"] is None:
53
+ uri = os.environ.get(
54
+ "MONGO_URI",
55
+ config.get(
56
+ "mongodb",
57
+ "uri",
58
+ fallback="mongodb://root:root@localhost:27017/",
59
+ ),
60
+ )
61
+ database_name = os.environ.get(
62
+ "MONGO_DATABASE",
63
+ config.get("mongodb", "database", fallback="LightRAG"),
64
+ )
65
+ client = AsyncIOMotorClient(uri)
66
+ db = client.get_database(database_name)
67
+ cls._instances["db"] = db
68
+ cls._instances["ref_count"] = 0
69
+ cls._instances["ref_count"] += 1
70
+ return cls._instances["db"]
71
+
72
+ @classmethod
73
+ async def release_client(cls, db: AsyncIOMotorDatabase):
74
+ async with cls._lock:
75
+ if db is not None:
76
+ if db is cls._instances["db"]:
77
+ cls._instances["ref_count"] -= 1
78
+ if cls._instances["ref_count"] == 0:
79
+ cls._instances["db"] = None
80
+
81
+
82
  @final
83
  @dataclass
84
  class MongoKVStorage(BaseKVStorage):
85
+ db: AsyncIOMotorDatabase = field(default=None)
86
+ _data: AsyncIOMotorCollection = field(default=None)
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
+ def __post_init__(self):
89
  self._collection_name = self.namespace
90
 
91
+ async def initialize(self):
92
+ if self.db is None:
93
+ self.db = await ClientManager.get_client()
94
+ self._data = await get_or_create_collection(self.db, self._collection_name)
95
+ logger.debug(f"Use MongoDB as KV {self._collection_name}")
96
 
97
+ async def finalize(self):
98
+ if self.db is not None:
99
+ await ClientManager.release_client(self.db)
100
+ self.db = None
101
+ self._data = None
102
 
103
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
104
  return await self._data.find_one({"_id": id})
 
155
  @final
156
  @dataclass
157
  class MongoDocStatusStorage(DocStatusStorage):
158
+ db: AsyncIOMotorDatabase = field(default=None)
159
+ _data: AsyncIOMotorCollection = field(default=None)
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
+ def __post_init__(self):
162
  self._collection_name = self.namespace
 
163
 
164
+ async def initialize(self):
165
+ if self.db is None:
166
+ self.db = await ClientManager.get_client()
167
+ self._data = await get_or_create_collection(self.db, self._collection_name)
168
+ logger.debug(f"Use MongoDB as DocStatus {self._collection_name}")
169
 
170
+ async def finalize(self):
171
+ if self.db is not None:
172
+ await ClientManager.release_client(self.db)
173
+ self.db = None
174
+ self._data = None
175
 
176
  async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
177
  return await self._data.find_one({"_id": id})
 
232
  @dataclass
233
  class MongoGraphStorage(BaseGraphStorage):
234
  """
235
+ A concrete implementation using MongoDB's $graphLookup to demonstrate multi-hop queries.
236
  """
237
 
238
+ db: AsyncIOMotorDatabase = field(default=None)
239
+ collection: AsyncIOMotorCollection = field(default=None)
240
+
241
  def __init__(self, namespace, global_config, embedding_func):
242
  super().__init__(
243
  namespace=namespace,
244
  global_config=global_config,
245
  embedding_func=embedding_func,
246
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  self._collection_name = self.namespace
 
248
 
249
+ async def initialize(self):
250
+ if self.db is None:
251
+ self.db = await ClientManager.get_client()
252
+ self.collection = await get_or_create_collection(
253
+ self.db, self._collection_name
254
+ )
255
+ logger.debug(f"Use MongoDB as KG {self._collection_name}")
256
 
257
+ async def finalize(self):
258
+ if self.db is not None:
259
+ await ClientManager.release_client(self.db)
260
+ self.db = None
261
+ self.collection = None
262
 
263
  #
264
  # -------------------------------------------------------------------------
 
797
  @final
798
  @dataclass
799
  class MongoVectorDBStorage(BaseVectorStorage):
800
+ db: AsyncIOMotorDatabase = field(default=None)
801
+ _data: AsyncIOMotorCollection = field(default=None)
802
+
803
  def __post_init__(self):
804
  kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
805
  cosine_threshold = kwargs.get("cosine_better_than_threshold")
 
808
  "cosine_better_than_threshold must be specified in vector_db_storage_cls_kwargs"
809
  )
810
  self.cosine_better_than_threshold = cosine_threshold
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
811
  self._collection_name = self.namespace
 
812
  self._max_batch_size = self.global_config["embedding_batch_num"]
813
 
814
+ async def initialize(self):
815
+ if self.db is None:
816
+ self.db = await ClientManager.get_client()
817
+ self._data = await get_or_create_collection(self.db, self._collection_name)
818
 
819
+ # Ensure vector index exists
820
+ await self.create_vector_index_if_not_exists()
821
 
822
+ logger.debug(f"Use MongoDB as VDB {self._collection_name}")
 
823
 
824
+ async def finalize(self):
825
+ if self.db is not None:
826
+ await ClientManager.release_client(self.db)
827
+ self.db = None
828
+ self._data = None
 
829
 
830
+ async def create_vector_index_if_not_exists(self):
831
+ """Creates an Atlas Vector Search index."""
832
  try:
833
+ index_name = "vector_knn_index"
834
+
835
+ indexes = await self._data.list_search_indexes().to_list(length=None)
836
+ for index in indexes:
837
+ if index["name"] == index_name:
838
+ logger.debug("vector index already exist")
839
+ return
840
+
841
  search_index_model = SearchIndexModel(
842
  definition={
843
  "fields": [
 
849
  }
850
  ]
851
  },
852
+ name=index_name,
853
  type="vectorSearch",
854
  )
855
 
856
+ await self._data.create_search_index(search_index_model)
857
  logger.info("Vector index created successfully.")
858
 
859
  except PyMongoError as _:
 
938
  raise NotImplementedError
939
 
940
 
941
+ async def get_or_create_collection(db: AsyncIOMotorDatabase, collection_name: str):
942
+ collection_names = await db.list_collection_names()
 
 
 
 
943
 
944
  if collection_name not in collection_names:
945
+ collection = await db.create_collection(collection_name)
946
  logger.info(f"Created collection: {collection_name}")
947
+ return collection
948
  else:
949
  logger.debug(f"Collection '{collection_name}' already exists.")
950
+ return db.get_collection(collection_name)
lightrag/kg/oracle_impl.py CHANGED
@@ -2,11 +2,11 @@ import array
2
  import asyncio
3
 
4
  # import html
5
- # import os
6
- from dataclasses import dataclass
7
  from typing import Any, Union, final
8
-
9
  import numpy as np
 
10
 
11
  from lightrag.types import KnowledgeGraph
12
 
@@ -177,17 +177,91 @@ class OracleDB:
177
  raise
178
 
179
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  @final
181
  @dataclass
182
  class OracleKVStorage(BaseKVStorage):
183
- # db instance must be injected before use
184
- # db: OracleDB
185
  meta_fields = None
186
 
187
  def __post_init__(self):
188
  self._data = {}
189
  self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
190
 
 
 
 
 
 
 
 
 
 
191
  ################ QUERY METHODS ################
192
 
193
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -324,6 +398,8 @@ class OracleKVStorage(BaseKVStorage):
324
  @final
325
  @dataclass
326
  class OracleVectorDBStorage(BaseVectorStorage):
 
 
327
  def __post_init__(self):
328
  config = self.global_config.get("vector_db_storage_cls_kwargs", {})
329
  cosine_threshold = config.get("cosine_better_than_threshold")
@@ -333,6 +409,15 @@ class OracleVectorDBStorage(BaseVectorStorage):
333
  )
334
  self.cosine_better_than_threshold = cosine_threshold
335
 
 
 
 
 
 
 
 
 
 
336
  #################### query method ###############
337
  async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
338
  embeddings = await self.embedding_func([query])
@@ -369,9 +454,20 @@ class OracleVectorDBStorage(BaseVectorStorage):
369
  @final
370
  @dataclass
371
  class OracleGraphStorage(BaseGraphStorage):
 
 
372
  def __post_init__(self):
373
  self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
374
 
 
 
 
 
 
 
 
 
 
375
  #################### insert method ################
376
 
377
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
 
2
  import asyncio
3
 
4
  # import html
5
+ import os
6
+ from dataclasses import dataclass, field
7
  from typing import Any, Union, final
 
8
  import numpy as np
9
+ import configparser
10
 
11
  from lightrag.types import KnowledgeGraph
12
 
 
177
  raise
178
 
179
 
180
+ class ClientManager:
181
+ _instances = {"db": None, "ref_count": 0}
182
+ _lock = asyncio.Lock()
183
+
184
+ @staticmethod
185
+ def get_config():
186
+ config = configparser.ConfigParser()
187
+ config.read("config.ini", "utf-8")
188
+
189
+ return {
190
+ "user": os.environ.get(
191
+ "ORACLE_USER",
192
+ config.get("oracle", "user", fallback=None),
193
+ ),
194
+ "password": os.environ.get(
195
+ "ORACLE_PASSWORD",
196
+ config.get("oracle", "password", fallback=None),
197
+ ),
198
+ "dsn": os.environ.get(
199
+ "ORACLE_DSN",
200
+ config.get("oracle", "dsn", fallback=None),
201
+ ),
202
+ "config_dir": os.environ.get(
203
+ "ORACLE_CONFIG_DIR",
204
+ config.get("oracle", "config_dir", fallback=None),
205
+ ),
206
+ "wallet_location": os.environ.get(
207
+ "ORACLE_WALLET_LOCATION",
208
+ config.get("oracle", "wallet_location", fallback=None),
209
+ ),
210
+ "wallet_password": os.environ.get(
211
+ "ORACLE_WALLET_PASSWORD",
212
+ config.get("oracle", "wallet_password", fallback=None),
213
+ ),
214
+ "workspace": os.environ.get(
215
+ "ORACLE_WORKSPACE",
216
+ config.get("oracle", "workspace", fallback="default"),
217
+ ),
218
+ }
219
+
220
+ @classmethod
221
+ async def get_client(cls) -> OracleDB:
222
+ async with cls._lock:
223
+ if cls._instances["db"] is None:
224
+ config = ClientManager.get_config()
225
+ db = OracleDB(config)
226
+ await db.check_tables()
227
+ cls._instances["db"] = db
228
+ cls._instances["ref_count"] = 0
229
+ cls._instances["ref_count"] += 1
230
+ return cls._instances["db"]
231
+
232
+ @classmethod
233
+ async def release_client(cls, db: OracleDB):
234
+ async with cls._lock:
235
+ if db is not None:
236
+ if db is cls._instances["db"]:
237
+ cls._instances["ref_count"] -= 1
238
+ if cls._instances["ref_count"] == 0:
239
+ await db.pool.close()
240
+ logger.info("Closed OracleDB database connection pool")
241
+ cls._instances["db"] = None
242
+ else:
243
+ await db.pool.close()
244
+
245
+
246
  @final
247
  @dataclass
248
  class OracleKVStorage(BaseKVStorage):
249
+ db: OracleDB = field(default=None)
 
250
  meta_fields = None
251
 
252
  def __post_init__(self):
253
  self._data = {}
254
  self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
255
 
256
+ async def initialize(self):
257
+ if self.db is None:
258
+ self.db = await ClientManager.get_client()
259
+
260
+ async def finalize(self):
261
+ if self.db is not None:
262
+ await ClientManager.release_client(self.db)
263
+ self.db = None
264
+
265
  ################ QUERY METHODS ################
266
 
267
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
 
398
  @final
399
  @dataclass
400
  class OracleVectorDBStorage(BaseVectorStorage):
401
+ db: OracleDB = field(default=None)
402
+
403
  def __post_init__(self):
404
  config = self.global_config.get("vector_db_storage_cls_kwargs", {})
405
  cosine_threshold = config.get("cosine_better_than_threshold")
 
409
  )
410
  self.cosine_better_than_threshold = cosine_threshold
411
 
412
+ async def initialize(self):
413
+ if self.db is None:
414
+ self.db = await ClientManager.get_client()
415
+
416
+ async def finalize(self):
417
+ if self.db is not None:
418
+ await ClientManager.release_client(self.db)
419
+ self.db = None
420
+
421
  #################### query method ###############
422
  async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
423
  embeddings = await self.embedding_func([query])
 
454
  @final
455
  @dataclass
456
  class OracleGraphStorage(BaseGraphStorage):
457
+ db: OracleDB = field(default=None)
458
+
459
  def __post_init__(self):
460
  self._max_batch_size = self.global_config.get("embedding_batch_num", 10)
461
 
462
+ async def initialize(self):
463
+ if self.db is None:
464
+ self.db = await ClientManager.get_client()
465
+
466
+ async def finalize(self):
467
+ if self.db is not None:
468
+ await ClientManager.release_client(self.db)
469
+ self.db = None
470
+
471
  #################### insert method ################
472
 
473
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
lightrag/kg/postgres_impl.py CHANGED
@@ -3,10 +3,10 @@ import inspect
3
  import json
4
  import os
5
  import time
6
- from dataclasses import dataclass
7
  from typing import Any, Dict, List, Union, final
8
-
9
  import numpy as np
 
10
 
11
  from lightrag.types import KnowledgeGraph
12
 
@@ -181,15 +181,84 @@ class PostgreSQLDB:
181
  pass
182
 
183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  @final
185
  @dataclass
186
  class PGKVStorage(BaseKVStorage):
187
- # db instance must be injected before use
188
- # db: PostgreSQLDB
189
 
190
  def __post_init__(self):
191
  self._max_batch_size = self.global_config["embedding_batch_num"]
192
 
 
 
 
 
 
 
 
 
 
193
  ################ QUERY METHODS ################
194
 
195
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -308,6 +377,8 @@ class PGKVStorage(BaseKVStorage):
308
  @final
309
  @dataclass
310
  class PGVectorStorage(BaseVectorStorage):
 
 
311
  def __post_init__(self):
312
  self._max_batch_size = self.global_config["embedding_batch_num"]
313
  config = self.global_config.get("vector_db_storage_cls_kwargs", {})
@@ -318,6 +389,15 @@ class PGVectorStorage(BaseVectorStorage):
318
  )
319
  self.cosine_better_than_threshold = cosine_threshold
320
 
 
 
 
 
 
 
 
 
 
321
  def _upsert_chunks(self, item: dict):
322
  try:
323
  upsert_sql = SQL_TEMPLATES["upsert_chunk"]
@@ -426,6 +506,17 @@ class PGVectorStorage(BaseVectorStorage):
426
  @final
427
  @dataclass
428
  class PGDocStatusStorage(DocStatusStorage):
 
 
 
 
 
 
 
 
 
 
 
429
  async def filter_keys(self, keys: set[str]) -> set[str]:
430
  """Filter out duplicated content"""
431
  sql = SQL_TEMPLATES["filter_keys"].format(
@@ -565,6 +656,8 @@ class PGGraphQueryException(Exception):
565
  @final
566
  @dataclass
567
  class PGGraphStorage(BaseGraphStorage):
 
 
568
  @staticmethod
569
  def load_nx_graph(file_name):
570
  print("no preloading of graph with AGE in production")
@@ -575,6 +668,15 @@ class PGGraphStorage(BaseGraphStorage):
575
  "node2vec": self._node2vec_embed,
576
  }
577
 
 
 
 
 
 
 
 
 
 
578
  async def index_done_callback(self) -> None:
579
  # PG handles persistence automatically
580
  pass
 
3
  import json
4
  import os
5
  import time
6
+ from dataclasses import dataclass, field
7
  from typing import Any, Dict, List, Union, final
 
8
  import numpy as np
9
+ import configparser
10
 
11
  from lightrag.types import KnowledgeGraph
12
 
 
181
  pass
182
 
183
 
184
+ class ClientManager:
185
+ _instances = {"db": None, "ref_count": 0}
186
+ _lock = asyncio.Lock()
187
+
188
+ @staticmethod
189
+ def get_config():
190
+ config = configparser.ConfigParser()
191
+ config.read("config.ini", "utf-8")
192
+
193
+ return {
194
+ "host": os.environ.get(
195
+ "POSTGRES_HOST",
196
+ config.get("postgres", "host", fallback="localhost"),
197
+ ),
198
+ "port": os.environ.get(
199
+ "POSTGRES_PORT", config.get("postgres", "port", fallback=5432)
200
+ ),
201
+ "user": os.environ.get(
202
+ "POSTGRES_USER", config.get("postgres", "user", fallback=None)
203
+ ),
204
+ "password": os.environ.get(
205
+ "POSTGRES_PASSWORD",
206
+ config.get("postgres", "password", fallback=None),
207
+ ),
208
+ "database": os.environ.get(
209
+ "POSTGRES_DATABASE",
210
+ config.get("postgres", "database", fallback=None),
211
+ ),
212
+ "workspace": os.environ.get(
213
+ "POSTGRES_WORKSPACE",
214
+ config.get("postgres", "workspace", fallback="default"),
215
+ ),
216
+ }
217
+
218
+ @classmethod
219
+ async def get_client(cls) -> PostgreSQLDB:
220
+ async with cls._lock:
221
+ if cls._instances["db"] is None:
222
+ config = ClientManager.get_config()
223
+ db = PostgreSQLDB(config)
224
+ await db.initdb()
225
+ await db.check_tables()
226
+ cls._instances["db"] = db
227
+ cls._instances["ref_count"] = 0
228
+ cls._instances["ref_count"] += 1
229
+ return cls._instances["db"]
230
+
231
+ @classmethod
232
+ async def release_client(cls, db: PostgreSQLDB):
233
+ async with cls._lock:
234
+ if db is not None:
235
+ if db is cls._instances["db"]:
236
+ cls._instances["ref_count"] -= 1
237
+ if cls._instances["ref_count"] == 0:
238
+ await db.pool.close()
239
+ logger.info("Closed PostgreSQL database connection pool")
240
+ cls._instances["db"] = None
241
+ else:
242
+ await db.pool.close()
243
+
244
+
245
  @final
246
  @dataclass
247
  class PGKVStorage(BaseKVStorage):
248
+ db: PostgreSQLDB = field(default=None)
 
249
 
250
  def __post_init__(self):
251
  self._max_batch_size = self.global_config["embedding_batch_num"]
252
 
253
+ async def initialize(self):
254
+ if self.db is None:
255
+ self.db = await ClientManager.get_client()
256
+
257
+ async def finalize(self):
258
+ if self.db is not None:
259
+ await ClientManager.release_client(self.db)
260
+ self.db = None
261
+
262
  ################ QUERY METHODS ################
263
 
264
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
 
377
  @final
378
  @dataclass
379
  class PGVectorStorage(BaseVectorStorage):
380
+ db: PostgreSQLDB = field(default=None)
381
+
382
  def __post_init__(self):
383
  self._max_batch_size = self.global_config["embedding_batch_num"]
384
  config = self.global_config.get("vector_db_storage_cls_kwargs", {})
 
389
  )
390
  self.cosine_better_than_threshold = cosine_threshold
391
 
392
+ async def initialize(self):
393
+ if self.db is None:
394
+ self.db = await ClientManager.get_client()
395
+
396
+ async def finalize(self):
397
+ if self.db is not None:
398
+ await ClientManager.release_client(self.db)
399
+ self.db = None
400
+
401
  def _upsert_chunks(self, item: dict):
402
  try:
403
  upsert_sql = SQL_TEMPLATES["upsert_chunk"]
 
506
  @final
507
  @dataclass
508
  class PGDocStatusStorage(DocStatusStorage):
509
+ db: PostgreSQLDB = field(default=None)
510
+
511
+ async def initialize(self):
512
+ if self.db is None:
513
+ self.db = await ClientManager.get_client()
514
+
515
+ async def finalize(self):
516
+ if self.db is not None:
517
+ await ClientManager.release_client(self.db)
518
+ self.db = None
519
+
520
  async def filter_keys(self, keys: set[str]) -> set[str]:
521
  """Filter out duplicated content"""
522
  sql = SQL_TEMPLATES["filter_keys"].format(
 
656
  @final
657
  @dataclass
658
  class PGGraphStorage(BaseGraphStorage):
659
+ db: PostgreSQLDB = field(default=None)
660
+
661
  @staticmethod
662
  def load_nx_graph(file_name):
663
  print("no preloading of graph with AGE in production")
 
668
  "node2vec": self._node2vec_embed,
669
  }
670
 
671
+ async def initialize(self):
672
+ if self.db is None:
673
+ self.db = await ClientManager.get_client()
674
+
675
+ async def finalize(self):
676
+ if self.db is not None:
677
+ await ClientManager.release_client(self.db)
678
+ self.db = None
679
+
680
  async def index_done_callback(self) -> None:
681
  # PG handles persistence automatically
682
  pass
lightrag/kg/tidb_impl.py CHANGED
@@ -1,6 +1,6 @@
1
  import asyncio
2
  import os
3
- from dataclasses import dataclass
4
  from typing import Any, Union, final
5
 
6
  import numpy as np
@@ -13,6 +13,7 @@ from ..namespace import NameSpace, is_namespace
13
  from ..utils import logger
14
 
15
  import pipmaster as pm
 
16
 
17
  if not pm.is_installed("pymysql"):
18
  pm.install("pymysql")
@@ -104,16 +105,81 @@ class TiDB:
104
  raise
105
 
106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  @final
108
  @dataclass
109
  class TiDBKVStorage(BaseKVStorage):
110
- # db instance must be injected before use
111
- # db: TiDB
112
 
113
  def __post_init__(self):
114
  self._data = {}
115
  self._max_batch_size = self.global_config["embedding_batch_num"]
116
 
 
 
 
 
 
 
 
 
 
117
  ################ QUERY METHODS ################
118
 
119
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
@@ -184,7 +250,7 @@ class TiDBKVStorage(BaseKVStorage):
184
  "tokens": item["tokens"],
185
  "chunk_order_index": item["chunk_order_index"],
186
  "full_doc_id": item["full_doc_id"],
187
- "content_vector": f'{item["__vector__"].tolist()}',
188
  "workspace": self.db.workspace,
189
  }
190
  )
@@ -212,6 +278,8 @@ class TiDBKVStorage(BaseKVStorage):
212
  @final
213
  @dataclass
214
  class TiDBVectorDBStorage(BaseVectorStorage):
 
 
215
  def __post_init__(self):
216
  self._client_file_name = os.path.join(
217
  self.global_config["working_dir"], f"vdb_{self.namespace}.json"
@@ -225,6 +293,15 @@ class TiDBVectorDBStorage(BaseVectorStorage):
225
  )
226
  self.cosine_better_than_threshold = cosine_threshold
227
 
 
 
 
 
 
 
 
 
 
228
  async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
229
  """Search from tidb vector"""
230
  embeddings = await self.embedding_func([query])
@@ -282,7 +359,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
282
  "id": item["id"],
283
  "name": item["entity_name"],
284
  "content": item["content"],
285
- "content_vector": f'{item["content_vector"].tolist()}',
286
  "workspace": self.db.workspace,
287
  }
288
  # update entity_id if node inserted by graph_storage_instance before
@@ -304,7 +381,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
304
  "source_name": item["src_id"],
305
  "target_name": item["tgt_id"],
306
  "content": item["content"],
307
- "content_vector": f'{item["content_vector"].tolist()}',
308
  "workspace": self.db.workspace,
309
  }
310
  # update relation_id if node inserted by graph_storage_instance before
@@ -337,12 +414,20 @@ class TiDBVectorDBStorage(BaseVectorStorage):
337
  @final
338
  @dataclass
339
  class TiDBGraphStorage(BaseGraphStorage):
340
- # db instance must be injected before use
341
- # db: TiDB
342
 
343
  def __post_init__(self):
344
  self._max_batch_size = self.global_config["embedding_batch_num"]
345
 
 
 
 
 
 
 
 
 
 
346
  #################### upsert method ################
347
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
348
  entity_name = node_id
 
1
  import asyncio
2
  import os
3
+ from dataclasses import dataclass, field
4
  from typing import Any, Union, final
5
 
6
  import numpy as np
 
13
  from ..utils import logger
14
 
15
  import pipmaster as pm
16
+ import configparser
17
 
18
  if not pm.is_installed("pymysql"):
19
  pm.install("pymysql")
 
105
  raise
106
 
107
 
108
+ class ClientManager:
109
+ _instances = {"db": None, "ref_count": 0}
110
+ _lock = asyncio.Lock()
111
+
112
+ @staticmethod
113
+ def get_config():
114
+ config = configparser.ConfigParser()
115
+ config.read("config.ini", "utf-8")
116
+
117
+ return {
118
+ "host": os.environ.get(
119
+ "TIDB_HOST",
120
+ config.get("tidb", "host", fallback="localhost"),
121
+ ),
122
+ "port": os.environ.get(
123
+ "TIDB_PORT", config.get("tidb", "port", fallback=4000)
124
+ ),
125
+ "user": os.environ.get(
126
+ "TIDB_USER",
127
+ config.get("tidb", "user", fallback=None),
128
+ ),
129
+ "password": os.environ.get(
130
+ "TIDB_PASSWORD",
131
+ config.get("tidb", "password", fallback=None),
132
+ ),
133
+ "database": os.environ.get(
134
+ "TIDB_DATABASE",
135
+ config.get("tidb", "database", fallback=None),
136
+ ),
137
+ "workspace": os.environ.get(
138
+ "TIDB_WORKSPACE",
139
+ config.get("tidb", "workspace", fallback="default"),
140
+ ),
141
+ }
142
+
143
+ @classmethod
144
+ async def get_client(cls) -> TiDB:
145
+ async with cls._lock:
146
+ if cls._instances["db"] is None:
147
+ config = ClientManager.get_config()
148
+ db = TiDB(config)
149
+ await db.check_tables()
150
+ cls._instances["db"] = db
151
+ cls._instances["ref_count"] = 0
152
+ cls._instances["ref_count"] += 1
153
+ return cls._instances["db"]
154
+
155
+ @classmethod
156
+ async def release_client(cls, db: TiDB):
157
+ async with cls._lock:
158
+ if db is not None:
159
+ if db is cls._instances["db"]:
160
+ cls._instances["ref_count"] -= 1
161
+ if cls._instances["ref_count"] == 0:
162
+ cls._instances["db"] = None
163
+
164
+
165
  @final
166
  @dataclass
167
  class TiDBKVStorage(BaseKVStorage):
168
+ db: TiDB = field(default=None)
 
169
 
170
  def __post_init__(self):
171
  self._data = {}
172
  self._max_batch_size = self.global_config["embedding_batch_num"]
173
 
174
+ async def initialize(self):
175
+ if self.db is None:
176
+ self.db = await ClientManager.get_client()
177
+
178
+ async def finalize(self):
179
+ if self.db is not None:
180
+ await ClientManager.release_client(self.db)
181
+ self.db = None
182
+
183
  ################ QUERY METHODS ################
184
 
185
  async def get_by_id(self, id: str) -> dict[str, Any] | None:
 
250
  "tokens": item["tokens"],
251
  "chunk_order_index": item["chunk_order_index"],
252
  "full_doc_id": item["full_doc_id"],
253
+ "content_vector": f"{item['__vector__'].tolist()}",
254
  "workspace": self.db.workspace,
255
  }
256
  )
 
278
  @final
279
  @dataclass
280
  class TiDBVectorDBStorage(BaseVectorStorage):
281
+ db: TiDB = field(default=None)
282
+
283
  def __post_init__(self):
284
  self._client_file_name = os.path.join(
285
  self.global_config["working_dir"], f"vdb_{self.namespace}.json"
 
293
  )
294
  self.cosine_better_than_threshold = cosine_threshold
295
 
296
+ async def initialize(self):
297
+ if self.db is None:
298
+ self.db = await ClientManager.get_client()
299
+
300
+ async def finalize(self):
301
+ if self.db is not None:
302
+ await ClientManager.release_client(self.db)
303
+ self.db = None
304
+
305
  async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
306
  """Search from tidb vector"""
307
  embeddings = await self.embedding_func([query])
 
359
  "id": item["id"],
360
  "name": item["entity_name"],
361
  "content": item["content"],
362
+ "content_vector": f"{item['content_vector'].tolist()}",
363
  "workspace": self.db.workspace,
364
  }
365
  # update entity_id if node inserted by graph_storage_instance before
 
381
  "source_name": item["src_id"],
382
  "target_name": item["tgt_id"],
383
  "content": item["content"],
384
+ "content_vector": f"{item['content_vector'].tolist()}",
385
  "workspace": self.db.workspace,
386
  }
387
  # update relation_id if node inserted by graph_storage_instance before
 
414
  @final
415
  @dataclass
416
  class TiDBGraphStorage(BaseGraphStorage):
417
+ db: TiDB = field(default=None)
 
418
 
419
  def __post_init__(self):
420
  self._max_batch_size = self.global_config["embedding_batch_num"]
421
 
422
+ async def initialize(self):
423
+ if self.db is None:
424
+ self.db = await ClientManager.get_client()
425
+
426
+ async def finalize(self):
427
+ if self.db is not None:
428
+ await ClientManager.release_client(self.db)
429
+ self.db = None
430
+
431
  #################### upsert method ################
432
  async def upsert_node(self, node_id: str, node_data: dict[str, str]) -> None:
433
  entity_name = node_id
lightrag/lightrag.py CHANGED
@@ -17,6 +17,7 @@ from .base import (
17
  DocStatusStorage,
18
  QueryParam,
19
  StorageNameSpace,
 
20
  )
21
  from .namespace import NameSpace, make_namespace
22
  from .operate import (
@@ -348,6 +349,10 @@ class LightRAG:
348
  # Extensions
349
  addon_params: dict[str, Any] = field(default_factory=dict)
350
 
 
 
 
 
351
  """Dictionary for additional parameters and extensions."""
352
  convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
353
  convert_response_to_json
@@ -440,7 +445,10 @@ class LightRAG:
440
  **self.vector_db_storage_cls_kwargs,
441
  }
442
 
443
- # show config
 
 
 
444
  global_config = asdict(self)
445
  _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
446
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
@@ -547,6 +555,65 @@ class LightRAG:
547
  )
548
  )
549
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
550
  async def get_graph_labels(self):
551
  text = await self.chunk_entity_relation_graph.get_all_labels()
552
  return text
 
17
  DocStatusStorage,
18
  QueryParam,
19
  StorageNameSpace,
20
+ StoragesStatus,
21
  )
22
  from .namespace import NameSpace, make_namespace
23
  from .operate import (
 
349
  # Extensions
350
  addon_params: dict[str, Any] = field(default_factory=dict)
351
 
352
+ # Storages Management
353
+ auto_manage_storages_states: bool = True
354
+ """If True, lightrag will automatically calls initialize_storages and finalize_storages at the appropriate times."""
355
+
356
  """Dictionary for additional parameters and extensions."""
357
  convert_response_to_json_func: Callable[[str], dict[str, Any]] = (
358
  convert_response_to_json
 
445
  **self.vector_db_storage_cls_kwargs,
446
  }
447
 
448
+ # Life cycle
449
+ self.storages_status = StoragesStatus.NOT_CREATED
450
+
451
+ # Show config
452
  global_config = asdict(self)
453
  _print_config = ",\n ".join([f"{k} = {v}" for k, v in global_config.items()])
454
  logger.debug(f"LightRAG init with param:\n {_print_config}\n")
 
555
  )
556
  )
557
 
558
+ self.storages_status = StoragesStatus.CREATED
559
+
560
+ # Initialize storages
561
+ if self.auto_manage_storages_states:
562
+ loop = always_get_an_event_loop()
563
+ loop.run_until_complete(self.initialize_storages())
564
+
565
+ def __del__(self):
566
+ # Finalize storages
567
+ if self.auto_manage_storages_states:
568
+ loop = always_get_an_event_loop()
569
+ loop.run_until_complete(self.finalize_storages())
570
+
571
+ async def initialize_storages(self):
572
+ """Asynchronously initialize the storages"""
573
+ if self.storages_status == StoragesStatus.CREATED:
574
+ tasks = []
575
+
576
+ for storage in (
577
+ self.full_docs,
578
+ self.text_chunks,
579
+ self.entities_vdb,
580
+ self.relationships_vdb,
581
+ self.chunks_vdb,
582
+ self.chunk_entity_relation_graph,
583
+ self.llm_response_cache,
584
+ self.doc_status,
585
+ ):
586
+ if storage:
587
+ tasks.append(storage.initialize())
588
+
589
+ await asyncio.gather(*tasks)
590
+
591
+ self.storages_status = StoragesStatus.INITIALIZED
592
+ logger.debug("Initialized Storages")
593
+
594
+ async def finalize_storages(self):
595
+ """Asynchronously finalize the storages"""
596
+ if self.storages_status == StoragesStatus.INITIALIZED:
597
+ tasks = []
598
+
599
+ for storage in (
600
+ self.full_docs,
601
+ self.text_chunks,
602
+ self.entities_vdb,
603
+ self.relationships_vdb,
604
+ self.chunks_vdb,
605
+ self.chunk_entity_relation_graph,
606
+ self.llm_response_cache,
607
+ self.doc_status,
608
+ ):
609
+ if storage:
610
+ tasks.append(storage.finalize())
611
+
612
+ await asyncio.gather(*tasks)
613
+
614
+ self.storages_status = StoragesStatus.FINALIZED
615
+ logger.debug("Finalized Storages")
616
+
617
  async def get_graph_labels(self):
618
  text = await self.chunk_entity_relation_graph.get_all_labels()
619
  return text