ArnoChen committed on
Commit 056dbb4 · 1 Parent(s): 1e09d54

better handling of namespace

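This commit centralizes LightRAG's storage namespace strings in a new lightrag/namespace.py module and replaces the scattered `self.namespace.endswith("...")` string checks in the Mongo, Oracle, PostgreSQL, and TiDB backends with shared `NameSpace` constants plus an `is_namespace()` helper. A minimal sketch of the resulting pattern (the `handles_llm_cache` wrapper is illustrative, not part of the diff):

    from lightrag.namespace import NameSpace, is_namespace

    def handles_llm_cache(namespace: str) -> bool:
        # Before this commit the backends compared raw suffixes:
        #     namespace.endswith("llm_response_cache")
        # Afterwards they share one constant and one helper; is_namespace()
        # still matches prefixed namespaces because it checks the suffix.
        return is_namespace(namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE)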
examples/copy_llm_cache_to_another_storage.py CHANGED
@@ -11,6 +11,7 @@ from dotenv import load_dotenv
 
 from lightrag.kg.postgres_impl import PostgreSQLDB, PGKVStorage
 from lightrag.storage import JsonKVStorage
+from lightrag.namespace import NameSpace
 
 load_dotenv()
 ROOT_DIR = os.environ.get("ROOT_DIR")
@@ -39,14 +40,14 @@ async def copy_from_postgres_to_json():
     await postgres_db.initdb()
 
     from_llm_response_cache = PGKVStorage(
-        namespace="llm_response_cache",
+        namespace=NameSpace.KV_STORE_LLM_RESPONSE_CACHE,
         global_config={"embedding_batch_num": 6},
         embedding_func=None,
         db=postgres_db,
     )
 
     to_llm_response_cache = JsonKVStorage(
-        namespace="llm_response_cache",
+        namespace=NameSpace.KV_STORE_LLM_RESPONSE_CACHE,
         global_config={"working_dir": WORKING_DIR},
         embedding_func=None,
     )
@@ -72,13 +73,13 @@ async def copy_from_json_to_postgres():
     await postgres_db.initdb()
 
     from_llm_response_cache = JsonKVStorage(
-        namespace="llm_response_cache",
+        namespace=NameSpace.KV_STORE_LLM_RESPONSE_CACHE,
         global_config={"working_dir": WORKING_DIR},
         embedding_func=None,
     )
 
     to_llm_response_cache = PGKVStorage(
-        namespace="llm_response_cache",
+        namespace=NameSpace.KV_STORE_LLM_RESPONSE_CACHE,
         global_config={"embedding_batch_num": 6},
         embedding_func=None,
         db=postgres_db,
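
The NameSpace attributes are plain str constants, so the example keeps passing the same namespace value as before; only the spelling is centralized. A quick hedged check, standalone and not part of the example script:

    from lightrag.namespace import NameSpace

    # The constant resolves to the literal the script previously hard-coded.
    assert NameSpace.KV_STORE_LLM_RESPONSE_CACHE == "llm_response_cache"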
lightrag/kg/mongo_impl.py CHANGED
@@ -13,10 +13,10 @@ if not pm.is_installed("motor"):
 from pymongo import MongoClient
 from motor.motor_asyncio import AsyncIOMotorClient
 from typing import Union, List, Tuple
-from lightrag.utils import logger
 
-from lightrag.base import BaseKVStorage
-from lightrag.base import BaseGraphStorage
+from ..utils import logger
+from ..base import BaseKVStorage, BaseGraphStorage
+from ..namespace import NameSpace, is_namespace
 
 
 @dataclass
@@ -52,7 +52,7 @@ class MongoKVStorage(BaseKVStorage):
         return set([s for s in data if s not in existing_ids])
 
     async def upsert(self, data: dict[str, dict]):
-        if self.namespace.endswith("llm_response_cache"):
+        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
             for mode, items in data.items():
                 for k, v in tqdm_async(items.items(), desc="Upserting"):
                     key = f"{mode}_{k}"
@@ -69,7 +69,7 @@ class MongoKVStorage(BaseKVStorage):
         return data
 
     async def get_by_mode_and_id(self, mode: str, id: str) -> Union[dict, None]:
-        if self.namespace.endswith("llm_response_cache"):
+        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
             res = {}
             v = self._data.find_one({"_id": mode + "_" + id})
             if v:
lightrag/kg/oracle_impl.py CHANGED
@@ -19,6 +19,7 @@ from ..base import (
     BaseKVStorage,
     BaseVectorStorage,
 )
+from ..namespace import NameSpace, is_namespace
 
 import oracledb
 
@@ -185,7 +186,7 @@ class OracleKVStorage(BaseKVStorage):
         SQL = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"workspace": self.db.workspace, "id": id}
         # print("get_by_id:"+SQL)
-        if self.namespace.endswith("llm_response_cache"):
+        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
             array_res = await self.db.query(SQL, params, multirows=True)
             res = {}
             for row in array_res:
@@ -201,7 +202,7 @@ class OracleKVStorage(BaseKVStorage):
         """Specifically for llm_response_cache."""
         SQL = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
         params = {"workspace": self.db.workspace, "cache_mode": mode, "id": id}
-        if self.namespace.endswith("llm_response_cache"):
+        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
             array_res = await self.db.query(SQL, params, multirows=True)
             res = {}
             for row in array_res:
@@ -218,7 +219,7 @@ class OracleKVStorage(BaseKVStorage):
         params = {"workspace": self.db.workspace}
         # print("get_by_ids:"+SQL)
         res = await self.db.query(SQL, params, multirows=True)
-        if self.namespace.endswith("llm_response_cache"):
+        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
             modes = set()
             dict_res: dict[str, dict] = {}
             for row in res:
@@ -256,7 +257,7 @@ class OracleKVStorage(BaseKVStorage):
     async def filter_keys(self, keys: list[str]) -> set[str]:
         """Return keys that don't exist in storage"""
         SQL = SQL_TEMPLATES["filter_keys"].format(
-            table_name=N_T[self.namespace], ids=",".join([f"'{id}'" for id in keys])
+            table_name=namespace_to_table_name(self.namespace), ids=",".join([f"'{id}'" for id in keys])
         )
         params = {"workspace": self.db.workspace}
         res = await self.db.query(SQL, params, multirows=True)
@@ -269,7 +270,7 @@ class OracleKVStorage(BaseKVStorage):
 
     ################ INSERT METHODS ################
     async def upsert(self, data: dict[str, dict]):
-        if self.namespace.endswith("text_chunks"):
+        if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
             list_data = [
                 {
                     "id": k,
@@ -302,7 +303,7 @@ class OracleKVStorage(BaseKVStorage):
                     "status": item["status"],
                 }
                 await self.db.execute(merge_sql, _data)
-        if self.namespace.endswith("full_docs"):
+        if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
             for k, v in data.items():
                 # values.clear()
                 merge_sql = SQL_TEMPLATES["merge_doc_full"]
@@ -313,7 +314,7 @@ class OracleKVStorage(BaseKVStorage):
                 }
                 await self.db.execute(merge_sql, _data)
 
-        if self.namespace.endswith("llm_response_cache"):
+        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
            for mode, items in data.items():
                 for k, v in items.items():
                     upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
@@ -329,15 +330,16 @@
         return None
 
     async def change_status(self, id: str, status: str):
-        SQL = SQL_TEMPLATES["change_status"].format(table_name=N_T[self.namespace])
+        SQL = SQL_TEMPLATES["change_status"].format(table_name=namespace_to_table_name(self.namespace))
         params = {"workspace": self.db.workspace, "id": id, "status": status}
         await self.db.execute(SQL, params)
 
     async def index_done_callback(self):
-        for n in ("full_docs", "text_chunks"):
-            if self.namespace.endswith(n):
-                logger.info("full doc and chunk data had been saved into oracle db!")
-                break
+        if is_namespace(
+            self.namespace,
+            (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
+        ):
+            logger.info("full doc and chunk data had been saved into oracle db!")
 
 
 @dataclass
@@ -614,13 +616,19 @@ class OracleGraphStorage(BaseGraphStorage):
 
 
 N_T = {
-    "full_docs": "LIGHTRAG_DOC_FULL",
-    "text_chunks": "LIGHTRAG_DOC_CHUNKS",
-    "chunks": "LIGHTRAG_DOC_CHUNKS",
-    "entities": "LIGHTRAG_GRAPH_NODES",
-    "relationships": "LIGHTRAG_GRAPH_EDGES",
+    NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
+    NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
+    NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
+    NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_GRAPH_NODES",
+    NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_GRAPH_EDGES",
 }
 
+def namespace_to_table_name(namespace: str) -> str:
+    for k, v in N_T.items():
+        if is_namespace(namespace, k):
+            return v
+
+
 TABLES = {
     "LIGHTRAG_DOC_FULL": {
         "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (
lightrag/kg/postgres_impl.py CHANGED
@@ -32,6 +32,7 @@ from ..base import (
     BaseGraphStorage,
     T,
 )
+from ..namespace import NameSpace, is_namespace
 
 if sys.platform.startswith("win"):
     import asyncio.windows_events
@@ -187,7 +188,7 @@ class PGKVStorage(BaseKVStorage):
         """Get doc_full data by id."""
         sql = SQL_TEMPLATES["get_by_id_" + self.namespace]
         params = {"workspace": self.db.workspace, "id": id}
-        if self.namespace.endswith("llm_response_cache"):
+        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
             array_res = await self.db.query(sql, params, multirows=True)
             res = {}
             for row in array_res:
@@ -203,7 +204,7 @@ class PGKVStorage(BaseKVStorage):
         """Specifically for llm_response_cache."""
         sql = SQL_TEMPLATES["get_by_mode_id_" + self.namespace]
         params = {"workspace": self.db.workspace, mode: mode, "id": id}
-        if self.namespace.endswith("llm_response_cache"):
+        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
             array_res = await self.db.query(sql, params, multirows=True)
             res = {}
             for row in array_res:
@@ -219,7 +220,7 @@ class PGKVStorage(BaseKVStorage):
             ids=",".join([f"'{id}'" for id in ids])
         )
         params = {"workspace": self.db.workspace}
-        if self.namespace.endswith("llm_response_cache"):
+        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
             array_res = await self.db.query(sql, params, multirows=True)
             modes = set()
             dict_res: dict[str, dict] = {}
@@ -239,7 +240,7 @@ class PGKVStorage(BaseKVStorage):
         return None
 
     async def all_keys(self) -> list[dict]:
-        if self.namespace.endswith("llm_response_cache"):
+        if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
             sql = "select workspace,mode,id from lightrag_llm_cache"
             res = await self.db.query(sql, multirows=True)
             return res
@@ -251,7 +252,7 @@ class PGKVStorage(BaseKVStorage):
     async def filter_keys(self, keys: List[str]) -> Set[str]:
         """Filter out duplicated content"""
         sql = SQL_TEMPLATES["filter_keys"].format(
-            table_name=NAMESPACE_TABLE_MAP[self.namespace],
+            table_name=namespace_to_table_name(self.namespace),
             ids=",".join([f"'{id}'" for id in keys]),
         )
         params = {"workspace": self.db.workspace}
@@ -270,9 +271,9 @@ class PGKVStorage(BaseKVStorage):
 
     ################ INSERT METHODS ################
     async def upsert(self, data: Dict[str, dict]):
-        if self.namespace.endswith("text_chunks"):
+        if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
             pass
-        elif self.namespace.endswith("full_docs"):
+        elif is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
             for k, v in data.items():
                 upsert_sql = SQL_TEMPLATES["upsert_doc_full"]
                 _data = {
@@ -281,7 +282,7 @@ class PGKVStorage(BaseKVStorage):
                     "workspace": self.db.workspace,
                 }
                 await self.db.execute(upsert_sql, _data)
-        elif self.namespace.endswith("llm_response_cache"):
+        elif is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
             for mode, items in data.items():
                 for k, v in items.items():
                     upsert_sql = SQL_TEMPLATES["upsert_llm_response_cache"]
@@ -296,12 +297,11 @@ class PGKVStorage(BaseKVStorage):
                     await self.db.execute(upsert_sql, _data)
 
     async def index_done_callback(self):
-        for n in ("full_docs", "text_chunks"):
-            if self.namespace.endswith(n):
-                logger.info(
-                    "full doc and chunk data had been saved into postgresql db!"
-                )
-                break
+        if is_namespace(
+            self.namespace,
+            (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
+        ):
+            logger.info("full doc and chunk data had been saved into postgresql db!")
 
 
 @dataclass
@@ -393,11 +393,11 @@ class PGVectorStorage(BaseVectorStorage):
         for i, d in enumerate(list_data):
             d["__vector__"] = embeddings[i]
         for item in list_data:
-            if self.namespace.endswith("chunks"):
+            if is_namespace(self.namespace, NameSpace.VECTOR_STORE_CHUNKS):
                 upsert_sql, data = self._upsert_chunks(item)
-            elif self.namespace.endswith("entities"):
+            elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_ENTITIES):
                 upsert_sql, data = self._upsert_entities(item)
-            elif self.namespace.endswith("relationships"):
+            elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_RELATIONSHIPS):
                 upsert_sql, data = self._upsert_relationships(item)
             else:
                 raise ValueError(f"{self.namespace} is not supported")
@@ -1027,16 +1027,22 @@ class PGGraphStorage(BaseGraphStorage):
 
 
 NAMESPACE_TABLE_MAP = {
-    "full_docs": "LIGHTRAG_DOC_FULL",
-    "text_chunks": "LIGHTRAG_DOC_CHUNKS",
-    "chunks": "LIGHTRAG_DOC_CHUNKS",
-    "entities": "LIGHTRAG_VDB_ENTITY",
-    "relationships": "LIGHTRAG_VDB_RELATION",
-    "doc_status": "LIGHTRAG_DOC_STATUS",
-    "llm_response_cache": "LIGHTRAG_LLM_CACHE",
+    NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
+    NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
+    NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
+    NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_VDB_ENTITY",
+    NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_VDB_RELATION",
+    NameSpace.DOC_STATUS: "LIGHTRAG_DOC_STATUS",
+    NameSpace.KV_STORE_LLM_RESPONSE_CACHE: "LIGHTRAG_LLM_CACHE",
 }
 
 
+def namespace_to_table_name(namespace: str) -> str:
+    for k, v in NAMESPACE_TABLE_MAP.items():
+        if is_namespace(namespace, k):
+            return v
+
+
 TABLES = {
     "LIGHTRAG_DOC_FULL": {
         "ddl": """CREATE TABLE LIGHTRAG_DOC_FULL (
lightrag/kg/postgres_impl_test.py CHANGED
@@ -12,7 +12,9 @@ if not pm.is_installed("asyncpg"):
 import asyncpg
 import psycopg
 from psycopg_pool import AsyncConnectionPool
-from lightrag.kg.postgres_impl import PostgreSQLDB, PGGraphStorage
+
+from ..kg.postgres_impl import PostgreSQLDB, PGGraphStorage
+from ..namespace import NameSpace
 
 DB = "rag"
 USER = "rag"
@@ -76,7 +78,7 @@ db = PostgreSQLDB(
 async def query_with_age():
     await db.initdb()
     graph = PGGraphStorage(
-        namespace="chunk_entity_relation",
+        namespace=NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION,
         global_config={},
         embedding_func=None,
     )
@@ -92,7 +94,7 @@ async def query_with_age():
 async def create_edge_with_age():
     await db.initdb()
     graph = PGGraphStorage(
-        namespace="chunk_entity_relation",
+        namespace=NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION,
         global_config={},
         embedding_func=None,
     )
lightrag/kg/tidb_impl.py CHANGED
@@ -14,8 +14,9 @@ if not pm.is_installed("sqlalchemy"):
 from sqlalchemy import create_engine, text
 from tqdm import tqdm
 
-from lightrag.base import BaseVectorStorage, BaseKVStorage, BaseGraphStorage
-from lightrag.utils import logger
+from ..base import BaseVectorStorage, BaseKVStorage, BaseGraphStorage
+from ..utils import logger
+from ..namespace import NameSpace, is_namespace
 
 
 class TiDB(object):
@@ -138,8 +139,8 @@ class TiDBKVStorage(BaseKVStorage):
     async def filter_keys(self, keys: list[str]) -> set[str]:
         """Filter out duplicated content"""
         SQL = SQL_TEMPLATES["filter_keys"].format(
-            table_name=N_T[self.namespace],
-            id_field=N_ID[self.namespace],
+            table_name=namespace_to_table_name(self.namespace),
+            id_field=namespace_to_id(self.namespace),
             ids=",".join([f"'{id}'" for id in keys]),
         )
         try:
@@ -160,7 +161,7 @@ class TiDBKVStorage(BaseKVStorage):
     async def upsert(self, data: dict[str, dict]):
         left_data = {k: v for k, v in data.items() if k not in self._data}
         self._data.update(left_data)
-        if self.namespace.endswith("text_chunks"):
+        if is_namespace(self.namespace, NameSpace.KV_STORE_TEXT_CHUNKS):
             list_data = [
                 {
                     "__id__": k,
@@ -196,7 +197,7 @@ class TiDBKVStorage(BaseKVStorage):
             )
             await self.db.execute(merge_sql, data)
 
-        if self.namespace.endswith("full_docs"):
+        if is_namespace(self.namespace, NameSpace.KV_STORE_FULL_DOCS):
             merge_sql = SQL_TEMPLATES["upsert_doc_full"]
             data = []
             for k, v in self._data.items():
@@ -211,10 +212,11 @@ class TiDBKVStorage(BaseKVStorage):
         return left_data
 
     async def index_done_callback(self):
-        for n in ("full_docs", "text_chunks"):
-            if self.namespace.endswith(n):
-                logger.info("full doc and chunk data had been saved into TiDB db!")
-                break
+        if is_namespace(
+            self.namespace,
+            (NameSpace.KV_STORE_FULL_DOCS, NameSpace.KV_STORE_TEXT_CHUNKS),
+        ):
+            logger.info("full doc and chunk data had been saved into TiDB db!")
 
 
 @dataclass
@@ -260,7 +262,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
         if not len(data):
             logger.warning("You insert an empty data to vector DB")
             return []
-        if self.namespace.endswith("chunks"):
+        if is_namespace(self.namespace, NameSpace.VECTOR_STORE_CHUNKS):
             return []
         logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
 
@@ -290,7 +292,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
         for i, d in enumerate(list_data):
             d["content_vector"] = embeddings[i]
 
-        if self.namespace.endswith("entities"):
+        if is_namespace(self.namespace, NameSpace.VECTOR_STORE_ENTITIES):
             data = []
             for item in list_data:
                 param = {
@@ -311,7 +313,7 @@ class TiDBVectorDBStorage(BaseVectorStorage):
             merge_sql = SQL_TEMPLATES["insert_entity"]
             await self.db.execute(merge_sql, data)
 
-        elif self.namespace.endswith("relationships"):
+        elif is_namespace(self.namespace, NameSpace.VECTOR_STORE_RELATIONSHIPS):
             data = []
             for item in list_data:
                 param = {
@@ -470,20 +472,33 @@ class TiDBGraphStorage(BaseGraphStorage):
 
 
 N_T = {
-    "full_docs": "LIGHTRAG_DOC_FULL",
-    "text_chunks": "LIGHTRAG_DOC_CHUNKS",
-    "chunks": "LIGHTRAG_DOC_CHUNKS",
-    "entities": "LIGHTRAG_GRAPH_NODES",
-    "relationships": "LIGHTRAG_GRAPH_EDGES",
+    NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
+    NameSpace.KV_STORE_TEXT_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
+    NameSpace.VECTOR_STORE_CHUNKS: "LIGHTRAG_DOC_CHUNKS",
+    NameSpace.VECTOR_STORE_ENTITIES: "LIGHTRAG_GRAPH_NODES",
+    NameSpace.VECTOR_STORE_RELATIONSHIPS: "LIGHTRAG_GRAPH_EDGES",
 }
 N_ID = {
-    "full_docs": "doc_id",
-    "text_chunks": "chunk_id",
-    "chunks": "chunk_id",
-    "entities": "entity_id",
-    "relationships": "relation_id",
+    NameSpace.KV_STORE_FULL_DOCS: "doc_id",
+    NameSpace.KV_STORE_TEXT_CHUNKS: "chunk_id",
+    NameSpace.VECTOR_STORE_CHUNKS: "chunk_id",
+    NameSpace.VECTOR_STORE_ENTITIES: "entity_id",
+    NameSpace.VECTOR_STORE_RELATIONSHIPS: "relation_id",
 }
 
+
+def namespace_to_table_name(namespace: str) -> str:
+    for k, v in N_T.items():
+        if is_namespace(namespace, k):
+            return v
+
+
+def namespace_to_id(namespace: str) -> str:
+    for k, v in N_ID.items():
+        if is_namespace(namespace, k):
+            return v
+
+
 TABLES = {
     "LIGHTRAG_DOC_FULL": {
         "ddl": """
lightrag/lightrag.py CHANGED
@@ -35,6 +35,8 @@ from .base import (
     DocStatus,
 )
 
+from .namespace import NameSpace, make_namespace
+
 from .prompt import GRAPH_FIELD_SEP
 
 STORAGES = {
@@ -228,8 +230,13 @@ class LightRAG:
             self.graph_storage_cls, global_config=global_config
         )
 
+        self.json_doc_status_storage = self.key_string_value_json_storage_cls(
+            namespace=self.namespace_prefix + "json_doc_status_storage",
+            embedding_func=None,
+        )
+
         self.llm_response_cache = self.key_string_value_json_storage_cls(
-            namespace=self.namespace_prefix + "llm_response_cache",
+            namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
             embedding_func=self.embedding_func,
         )
 
@@ -237,34 +244,33 @@
         # add embedding func by walter
         ####
         self.full_docs = self.key_string_value_json_storage_cls(
-            namespace=self.namespace_prefix + "full_docs",
+            namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_FULL_DOCS),
             embedding_func=self.embedding_func,
         )
         self.text_chunks = self.key_string_value_json_storage_cls(
-            namespace=self.namespace_prefix + "text_chunks",
+            namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_TEXT_CHUNKS),
             embedding_func=self.embedding_func,
         )
         self.chunk_entity_relation_graph = self.graph_storage_cls(
-            namespace=self.namespace_prefix + "chunk_entity_relation",
+            namespace=make_namespace(self.namespace_prefix, NameSpace.GRAPH_STORE_CHUNK_ENTITY_RELATION),
             embedding_func=self.embedding_func,
         )
-
         ####
         # add embedding func by walter over
         ####
 
         self.entities_vdb = self.vector_db_storage_cls(
-            namespace=self.namespace_prefix + "entities",
+            namespace=make_namespace(self.namespace_prefix, NameSpace.VECTOR_STORE_ENTITIES),
             embedding_func=self.embedding_func,
             meta_fields={"entity_name"},
         )
         self.relationships_vdb = self.vector_db_storage_cls(
-            namespace=self.namespace_prefix + "relationships",
+            namespace=make_namespace(self.namespace_prefix, NameSpace.VECTOR_STORE_RELATIONSHIPS),
             embedding_func=self.embedding_func,
             meta_fields={"src_id", "tgt_id"},
         )
         self.chunks_vdb = self.vector_db_storage_cls(
-            namespace=self.namespace_prefix + "chunks",
+            namespace=make_namespace(self.namespace_prefix, NameSpace.VECTOR_STORE_CHUNKS),
             embedding_func=self.embedding_func,
         )
 
@@ -274,7 +280,7 @@ class LightRAG:
             hashing_kv = self.llm_response_cache
         else:
             hashing_kv = self.key_string_value_json_storage_cls(
-                namespace=self.namespace_prefix + "llm_response_cache",
+                namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
                 embedding_func=self.embedding_func,
             )
 
@@ -289,7 +295,7 @@ class LightRAG:
         # Initialize document status storage
         self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
         self.doc_status = self.doc_status_storage_cls(
-            namespace=self.namespace_prefix + "doc_status",
+            namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
             global_config=global_config,
             embedding_func=None,
         )
@@ -925,7 +931,7 @@ class LightRAG:
             if self.llm_response_cache
             and hasattr(self.llm_response_cache, "global_config")
             else self.key_string_value_json_storage_cls(
-                namespace=self.namespace_prefix + "llm_response_cache",
+                namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
                 global_config=asdict(self),
                 embedding_func=self.embedding_func,
             ),
@@ -942,7 +948,7 @@ class LightRAG:
             if self.llm_response_cache
             and hasattr(self.llm_response_cache, "global_config")
             else self.key_string_value_json_storage_cls(
-                namespace=self.namespace_prefix + "llm_response_cache",
+                namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
                 global_config=asdict(self),
                 embedding_func=self.embedding_func,
             ),
@@ -961,7 +967,7 @@ class LightRAG:
             if self.llm_response_cache
             and hasattr(self.llm_response_cache, "global_config")
             else self.key_string_value_json_storage_cls(
-                namespace=self.namespace_prefix + "llm_response_cache",
+                namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
                 global_config=asdict(self),
                 embedding_func=self.embedding_func,
             ),
@@ -1002,7 +1008,7 @@ class LightRAG:
                 global_config=asdict(self),
                 hashing_kv=self.llm_response_cache
                 or self.key_string_value_json_storage_cls(
-                    namespace=self.namespace_prefix + "llm_response_cache",
+                    namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
                     global_config=asdict(self),
                     embedding_func=self.embedding_func,
                 ),
@@ -1033,7 +1039,7 @@ class LightRAG:
             if self.llm_response_cache
             and hasattr(self.llm_response_cache, "global_config")
             else self.key_string_value_json_storage_cls(
-                namespace=self.namespace_prefix + "llm_response_cache",
+                namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
                 global_config=asdict(self),
                 embedding_func=self.embedding_funcne,
             ),
@@ -1049,7 +1055,7 @@ class LightRAG:
             if self.llm_response_cache
             and hasattr(self.llm_response_cache, "global_config")
             else self.key_string_value_json_storage_cls(
-                namespace=self.namespace_prefix + "llm_response_cache",
+                namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
                 global_config=asdict(self),
                 embedding_func=self.embedding_func,
             ),
@@ -1068,7 +1074,7 @@ class LightRAG:
             if self.llm_response_cache
             and hasattr(self.llm_response_cache, "global_config")
             else self.key_string_value_json_storage_cls(
-                namespace=self.namespace_prefix + "llm_response_cache",
+                namespace=make_namespace(self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE),
                 global_config=asdict(self),
                 embedding_func=self.embedding_func,
             ),
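
`make_namespace` simply concatenates the configured `namespace_prefix` with a base name, and `is_namespace` matches on the suffix, so prefixed storages still reach the same backend branches. A small illustration (the "demo_" prefix is made up):

    from lightrag.namespace import NameSpace, make_namespace, is_namespace

    ns = make_namespace("demo_", NameSpace.KV_STORE_LLM_RESPONSE_CACHE)
    assert ns == "demo_llm_response_cache"
    assert is_namespace(ns, NameSpace.KV_STORE_LLM_RESPONSE_CACHE)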
lightrag/namespace.py ADDED
@@ -0,0 +1,25 @@
+from typing import Iterable
+
+
+class NameSpace:
+    KV_STORE_FULL_DOCS = "full_docs"
+    KV_STORE_TEXT_CHUNKS = "text_chunks"
+    KV_STORE_LLM_RESPONSE_CACHE = "llm_response_cache"
+
+    VECTOR_STORE_ENTITIES = "entities"
+    VECTOR_STORE_RELATIONSHIPS = "relationships"
+    VECTOR_STORE_CHUNKS = "chunks"
+
+    GRAPH_STORE_CHUNK_ENTITY_RELATION = "chunk_entity_relation"
+
+    DOC_STATUS = "doc_status"
+
+
+def make_namespace(prefix: str, base_namespace: str):
+    return prefix + base_namespace
+
+
+def is_namespace(namespace: str, base_namespace: str | Iterable[str]):
+    if isinstance(base_namespace, str):
+        return namespace.endswith(base_namespace)
+    return any(is_namespace(namespace, ns) for ns in base_namespace)
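
The `str | Iterable[str]` union in `is_namespace` is evaluated when the module is imported, which assumes Python 3.10 or newer; on older interpreters an equivalent sketch (illustrative, not part of the commit) would spell the union through typing.Union:

    from typing import Iterable, Union

    def is_namespace(namespace: str, base_namespace: Union[str, Iterable[str]]) -> bool:
        # Same suffix-matching behavior, but the annotation also works on Python < 3.10.
        if isinstance(base_namespace, str):
            return namespace.endswith(base_namespace)
        return any(is_namespace(namespace, ns) for ns in base_namespace)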