YanSte committed on
Commit
516470b
·
1 Parent(s): 003f25b

cleaned code

Browse files
lightrag/base.py CHANGED
@@ -84,9 +84,6 @@ class BaseVectorStorage(StorageNameSpace):
84
  class BaseKVStorage(StorageNameSpace):
85
  embedding_func: EmbeddingFunc
86
 
87
- async def all_keys(self) -> list[str]:
88
- raise NotImplementedError
89
-
90
  async def get_by_id(self, id: str) -> dict[str, Any]:
91
  raise NotImplementedError
92
 
@@ -103,9 +100,6 @@ class BaseKVStorage(StorageNameSpace):
103
  async def drop(self) -> None:
104
  raise NotImplementedError
105
 
106
- async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
107
- raise NotImplementedError
108
-
109
 
110
  @dataclass
111
  class BaseGraphStorage(StorageNameSpace):
 
84
  class BaseKVStorage(StorageNameSpace):
85
  embedding_func: EmbeddingFunc
86
 
 
 
 
87
  async def get_by_id(self, id: str) -> dict[str, Any]:
88
  raise NotImplementedError
89
 
 
100
  async def drop(self) -> None:
101
  raise NotImplementedError
102
 
 
 
 
103
 
104
  @dataclass
105
  class BaseGraphStorage(StorageNameSpace):
lightrag/kg/json_kv_impl.py CHANGED
@@ -1,7 +1,7 @@
1
  import asyncio
2
  import os
3
  from dataclasses import dataclass
4
- from typing import Any, Union
5
 
6
  from lightrag.utils import (
7
  logger,
@@ -21,10 +21,7 @@ class JsonKVStorage(BaseKVStorage):
21
  self._data: dict[str, Any] = load_json(self._file_name) or {}
22
  self._lock = asyncio.Lock()
23
  logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
24
-
25
- async def all_keys(self) -> list[str]:
26
- return list(self._data.keys())
27
-
28
  async def index_done_callback(self):
29
  write_json(self._data, self._file_name)
30
 
@@ -49,8 +46,4 @@ class JsonKVStorage(BaseKVStorage):
49
  self._data.update(left_data)
50
 
51
  async def drop(self) -> None:
52
- self._data = {}
53
-
54
- async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
55
- result = [v for _, v in self._data.items() if v["status"] == status]
56
- return result if result else None
 
1
  import asyncio
2
  import os
3
  from dataclasses import dataclass
4
+ from typing import Any
5
 
6
  from lightrag.utils import (
7
  logger,
 
21
  self._data: dict[str, Any] = load_json(self._file_name) or {}
22
  self._lock = asyncio.Lock()
23
  logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
24
+
 
 
 
25
  async def index_done_callback(self):
26
  write_json(self._data, self._file_name)
27
 
 
46
  self._data.update(left_data)
47
 
48
  async def drop(self) -> None:
49
+ self._data = {}
 
 
 
 
lightrag/kg/mongo_impl.py CHANGED
@@ -29,9 +29,6 @@ class MongoKVStorage(BaseKVStorage):
29
  self._data = database.get_collection(self.namespace)
30
  logger.info(f"Use MongoDB as KV {self.namespace}")
31
 
32
- async def all_keys(self) -> list[str]:
33
- return [x["_id"] for x in self._data.find({}, {"_id": 1})]
34
-
35
  async def get_by_id(self, id: str) -> dict[str, Any]:
36
  return self._data.find_one({"_id": id})
37
 
@@ -77,11 +74,6 @@ class MongoKVStorage(BaseKVStorage):
77
  """Drop the collection"""
78
  await self._data.drop()
79
 
80
- async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
81
- """Get documents by status and ids"""
82
- return self._data.find({"status": status})
83
-
84
-
85
  @dataclass
86
  class MongoGraphStorage(BaseGraphStorage):
87
  """
 
29
  self._data = database.get_collection(self.namespace)
30
  logger.info(f"Use MongoDB as KV {self.namespace}")
31
 
 
 
 
32
  async def get_by_id(self, id: str) -> dict[str, Any]:
33
  return self._data.find_one({"_id": id})
34
 
 
74
  """Drop the collection"""
75
  await self._data.drop()
76
 
 
 
 
 
 
77
  @dataclass
78
  class MongoGraphStorage(BaseGraphStorage):
79
  """
lightrag/kg/oracle_impl.py CHANGED
@@ -229,12 +229,6 @@ class OracleKVStorage(BaseKVStorage):
229
  res = [{k: v} for k, v in dict_res.items()]
230
  return res
231
 
232
- async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
233
- """Specifically for llm_response_cache."""
234
- SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
235
- params = {"workspace": self.db.workspace, "status": status}
236
- return await self.db.query(SQL, params, multirows=True)
237
-
238
  async def filter_keys(self, keys: list[str]) -> set[str]:
239
  """Return keys that don't exist in storage"""
240
  SQL = SQL_TEMPLATES["filter_keys"].format(
 
229
  res = [{k: v} for k, v in dict_res.items()]
230
  return res
231
 
 
 
 
 
 
 
232
  async def filter_keys(self, keys: list[str]) -> set[str]:
233
  """Return keys that don't exist in storage"""
234
  SQL = SQL_TEMPLATES["filter_keys"].format(
lightrag/kg/postgres_impl.py CHANGED
@@ -237,16 +237,6 @@ class PGKVStorage(BaseKVStorage):
237
  params = {"workspace": self.db.workspace, "status": status}
238
  return await self.db.query(SQL, params, multirows=True)
239
 
240
- async def all_keys(self) -> list[dict]:
241
- if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
242
- sql = "select workspace,mode,id from lightrag_llm_cache"
243
- res = await self.db.query(sql, multirows=True)
244
- return res
245
- else:
246
- logger.error(
247
- f"all_keys is only implemented for llm_response_cache, not for {self.namespace}"
248
- )
249
-
250
  async def filter_keys(self, keys: List[str]) -> Set[str]:
251
  """Filter out duplicated content"""
252
  sql = SQL_TEMPLATES["filter_keys"].format(
 
237
  params = {"workspace": self.db.workspace, "status": status}
238
  return await self.db.query(SQL, params, multirows=True)
239
 
 
 
 
 
 
 
 
 
 
 
240
  async def filter_keys(self, keys: List[str]) -> Set[str]:
241
  """Filter out duplicated content"""
242
  sql = SQL_TEMPLATES["filter_keys"].format(
lightrag/kg/redis_impl.py CHANGED
@@ -1,5 +1,5 @@
1
  import os
2
- from typing import Any, Union
3
  from tqdm.asyncio import tqdm as tqdm_async
4
  from dataclasses import dataclass
5
  import pipmaster as pm
@@ -20,11 +20,7 @@ class RedisKVStorage(BaseKVStorage):
20
  redis_url = os.environ.get("REDIS_URI", "redis://localhost:6379")
21
  self._redis = Redis.from_url(redis_url, decode_responses=True)
22
  logger.info(f"Use Redis as KV {self.namespace}")
23
-
24
- async def all_keys(self) -> list[str]:
25
- keys = await self._redis.keys(f"{self.namespace}:*")
26
- return [key.split(":", 1)[-1] for key in keys]
27
-
28
  async def get_by_id(self, id):
29
  data = await self._redis.get(f"{self.namespace}:{id}")
30
  return json.loads(data) if data else None
@@ -57,11 +53,4 @@ class RedisKVStorage(BaseKVStorage):
57
  async def drop(self) -> None:
58
  keys = await self._redis.keys(f"{self.namespace}:*")
59
  if keys:
60
- await self._redis.delete(*keys)
61
-
62
- async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
63
- pipe = self._redis.pipeline()
64
- for key in await self._redis.keys(f"{self.namespace}:*"):
65
- pipe.hgetall(key)
66
- results = await pipe.execute()
67
- return [data for data in results if data.get("status") == status] or None
 
1
  import os
2
+ from typing import Any
3
  from tqdm.asyncio import tqdm as tqdm_async
4
  from dataclasses import dataclass
5
  import pipmaster as pm
 
20
  redis_url = os.environ.get("REDIS_URI", "redis://localhost:6379")
21
  self._redis = Redis.from_url(redis_url, decode_responses=True)
22
  logger.info(f"Use Redis as KV {self.namespace}")
23
+
 
 
 
 
24
  async def get_by_id(self, id):
25
  data = await self._redis.get(f"{self.namespace}:{id}")
26
  return json.loads(data) if data else None
 
53
  async def drop(self) -> None:
54
  keys = await self._redis.keys(f"{self.namespace}:*")
55
  if keys:
56
+ await self._redis.delete(*keys)
 
 
 
 
 
 
 
lightrag/lightrag.py CHANGED
@@ -29,6 +29,7 @@ from .base import (
29
  BaseKVStorage,
30
  BaseVectorStorage,
31
  DocStatus,
 
32
  QueryParam,
33
  StorageNameSpace,
34
  )
@@ -319,7 +320,7 @@ class LightRAG:
319
 
320
  # Initialize document status storage
321
  self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
322
- self.doc_status: BaseKVStorage = self.doc_status_storage_cls(
323
  namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
324
  global_config=global_config,
325
  embedding_func=None,
@@ -394,10 +395,8 @@ class LightRAG:
394
  split_by_character_only: if split_by_character_only is True, split the string by character only, when
395
  split_by_character is None, this parameter is ignored.
396
  """
397
- await self.apipeline_process_documents(string_or_strings)
398
- await self.apipeline_process_enqueue_documents(
399
- split_by_character, split_by_character_only
400
- )
401
 
402
  def insert_custom_chunks(self, full_text: str, text_chunks: list[str]):
403
  loop = always_get_an_event_loop()
@@ -496,8 +495,13 @@ class LightRAG:
496
 
497
  # 3. Filter out already processed documents
498
  add_doc_keys: set[str] = set()
499
- excluded_ids = await self.doc_status.all_keys()
 
 
 
 
500
  add_doc_keys = new_docs.keys() - excluded_ids
 
501
  new_docs = {k: v for k, v in new_docs.items() if k in add_doc_keys}
502
 
503
  if not new_docs:
@@ -513,12 +517,12 @@ class LightRAG:
513
  to_process_doc_keys: list[str] = []
514
 
515
  # Fetch failed documents
516
- failed_docs = await self.doc_status.get_by_status(status=DocStatus.FAILED)
517
  if failed_docs:
518
  to_process_doc_keys.extend([doc["id"] for doc in failed_docs])
519
 
520
  # Fetch pending documents
521
- pending_docs = await self.doc_status.get_by_status(status=DocStatus.PENDING)
522
  if pending_docs:
523
  to_process_doc_keys.extend([doc["id"] for doc in pending_docs])
524
 
 
29
  BaseKVStorage,
30
  BaseVectorStorage,
31
  DocStatus,
32
+ DocStatusStorage,
33
  QueryParam,
34
  StorageNameSpace,
35
  )
 
320
 
321
  # Initialize document status storage
322
  self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
323
+ self.doc_status: DocStatusStorage = self.doc_status_storage_cls(
324
  namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
325
  global_config=global_config,
326
  embedding_func=None,
 
395
  split_by_character_only: if split_by_character_only is True, split the string by character only, when
396
  split_by_character is None, this parameter is ignored.
397
  """
398
+ await self.apipeline_enqueue_documents(string_or_strings)
399
+ await self.apipeline_process_enqueue_documents(split_by_character, split_by_character_only)
 
 
400
 
401
  def insert_custom_chunks(self, full_text: str, text_chunks: list[str]):
402
  loop = always_get_an_event_loop()
 
495
 
496
  # 3. Filter out already processed documents
497
  add_doc_keys: set[str] = set()
498
+ # Get docs ids
499
+ in_process_keys = list(new_docs.keys())
500
+ # Get in progress docs ids
501
+ excluded_ids = await self.doc_status.get_by_ids(in_process_keys)
502
+ # Exclude already in process
503
  add_doc_keys = new_docs.keys() - excluded_ids
504
+ # Filter
505
  new_docs = {k: v for k, v in new_docs.items() if k in add_doc_keys}
506
 
507
  if not new_docs:
 
517
  to_process_doc_keys: list[str] = []
518
 
519
  # Fetch failed documents
520
+ failed_docs = await self.doc_status.get_failed_docs()
521
  if failed_docs:
522
  to_process_doc_keys.extend([doc["id"] for doc in failed_docs])
523
 
524
  # Fetch pending documents
525
+ pending_docs = await self.doc_status.get_pending_docs()
526
  if pending_docs:
527
  to_process_doc_keys.extend([doc["id"] for doc in pending_docs])
528