cleaned code
Browse files
- lightrag/base.py +0 -6
- lightrag/kg/json_kv_impl.py +3 -10
- lightrag/kg/mongo_impl.py +0 -8
- lightrag/kg/oracle_impl.py +0 -6
- lightrag/kg/postgres_impl.py +0 -10
- lightrag/kg/redis_impl.py +3 -14
- lightrag/lightrag.py +12 -8
lightrag/base.py
CHANGED
@@ -84,9 +84,6 @@ class BaseVectorStorage(StorageNameSpace):
|
|
84 |
class BaseKVStorage(StorageNameSpace):
|
85 |
embedding_func: EmbeddingFunc
|
86 |
|
87 |
-
async def all_keys(self) -> list[str]:
|
88 |
-
raise NotImplementedError
|
89 |
-
|
90 |
async def get_by_id(self, id: str) -> dict[str, Any]:
|
91 |
raise NotImplementedError
|
92 |
|
@@ -103,9 +100,6 @@ class BaseKVStorage(StorageNameSpace):
|
|
103 |
async def drop(self) -> None:
|
104 |
raise NotImplementedError
|
105 |
|
106 |
-
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
|
107 |
-
raise NotImplementedError
|
108 |
-
|
109 |
|
110 |
@dataclass
|
111 |
class BaseGraphStorage(StorageNameSpace):
|
|
|
84 |
class BaseKVStorage(StorageNameSpace):
|
85 |
embedding_func: EmbeddingFunc
|
86 |
|
|
|
|
|
|
|
87 |
async def get_by_id(self, id: str) -> dict[str, Any]:
|
88 |
raise NotImplementedError
|
89 |
|
|
|
100 |
async def drop(self) -> None:
|
101 |
raise NotImplementedError
|
102 |
|
|
|
|
|
|
|
103 |
|
104 |
@dataclass
|
105 |
class BaseGraphStorage(StorageNameSpace):
|
lightrag/kg/json_kv_impl.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
import asyncio
|
2 |
import os
|
3 |
from dataclasses import dataclass
|
4 |
-
from typing import Any
|
5 |
|
6 |
from lightrag.utils import (
|
7 |
logger,
|
@@ -21,10 +21,7 @@ class JsonKVStorage(BaseKVStorage):
|
|
21 |
self._data: dict[str, Any] = load_json(self._file_name) or {}
|
22 |
self._lock = asyncio.Lock()
|
23 |
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
|
24 |
-
|
25 |
-
async def all_keys(self) -> list[str]:
|
26 |
-
return list(self._data.keys())
|
27 |
-
|
28 |
async def index_done_callback(self):
|
29 |
write_json(self._data, self._file_name)
|
30 |
|
@@ -49,8 +46,4 @@ class JsonKVStorage(BaseKVStorage):
|
|
49 |
self._data.update(left_data)
|
50 |
|
51 |
async def drop(self) -> None:
|
52 |
-
self._data = {}
|
53 |
-
|
54 |
-
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
|
55 |
-
result = [v for _, v in self._data.items() if v["status"] == status]
|
56 |
-
return result if result else None
|
|
|
1 |
import asyncio
|
2 |
import os
|
3 |
from dataclasses import dataclass
|
4 |
+
from typing import Any
|
5 |
|
6 |
from lightrag.utils import (
|
7 |
logger,
|
|
|
21 |
self._data: dict[str, Any] = load_json(self._file_name) or {}
|
22 |
self._lock = asyncio.Lock()
|
23 |
logger.info(f"Load KV {self.namespace} with {len(self._data)} data")
|
24 |
+
|
|
|
|
|
|
|
25 |
async def index_done_callback(self):
|
26 |
write_json(self._data, self._file_name)
|
27 |
|
|
|
46 |
self._data.update(left_data)
|
47 |
|
48 |
async def drop(self) -> None:
|
49 |
+
self._data = {}
|
|
|
|
|
|
|
|
lightrag/kg/mongo_impl.py
CHANGED
@@ -29,9 +29,6 @@ class MongoKVStorage(BaseKVStorage):
|
|
29 |
self._data = database.get_collection(self.namespace)
|
30 |
logger.info(f"Use MongoDB as KV {self.namespace}")
|
31 |
|
32 |
-
async def all_keys(self) -> list[str]:
|
33 |
-
return [x["_id"] for x in self._data.find({}, {"_id": 1})]
|
34 |
-
|
35 |
async def get_by_id(self, id: str) -> dict[str, Any]:
|
36 |
return self._data.find_one({"_id": id})
|
37 |
|
@@ -77,11 +74,6 @@ class MongoKVStorage(BaseKVStorage):
|
|
77 |
"""Drop the collection"""
|
78 |
await self._data.drop()
|
79 |
|
80 |
-
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
|
81 |
-
"""Get documents by status and ids"""
|
82 |
-
return self._data.find({"status": status})
|
83 |
-
|
84 |
-
|
85 |
@dataclass
|
86 |
class MongoGraphStorage(BaseGraphStorage):
|
87 |
"""
|
|
|
29 |
self._data = database.get_collection(self.namespace)
|
30 |
logger.info(f"Use MongoDB as KV {self.namespace}")
|
31 |
|
|
|
|
|
|
|
32 |
async def get_by_id(self, id: str) -> dict[str, Any]:
|
33 |
return self._data.find_one({"_id": id})
|
34 |
|
|
|
74 |
"""Drop the collection"""
|
75 |
await self._data.drop()
|
76 |
|
|
|
|
|
|
|
|
|
|
|
77 |
@dataclass
|
78 |
class MongoGraphStorage(BaseGraphStorage):
|
79 |
"""
|
lightrag/kg/oracle_impl.py
CHANGED
@@ -229,12 +229,6 @@ class OracleKVStorage(BaseKVStorage):
|
|
229 |
res = [{k: v} for k, v in dict_res.items()]
|
230 |
return res
|
231 |
|
232 |
-
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
|
233 |
-
"""Specifically for llm_response_cache."""
|
234 |
-
SQL = SQL_TEMPLATES["get_by_status_" + self.namespace]
|
235 |
-
params = {"workspace": self.db.workspace, "status": status}
|
236 |
-
return await self.db.query(SQL, params, multirows=True)
|
237 |
-
|
238 |
async def filter_keys(self, keys: list[str]) -> set[str]:
|
239 |
"""Return keys that don't exist in storage"""
|
240 |
SQL = SQL_TEMPLATES["filter_keys"].format(
|
|
|
229 |
res = [{k: v} for k, v in dict_res.items()]
|
230 |
return res
|
231 |
|
|
|
|
|
|
|
|
|
|
|
|
|
232 |
async def filter_keys(self, keys: list[str]) -> set[str]:
|
233 |
"""Return keys that don't exist in storage"""
|
234 |
SQL = SQL_TEMPLATES["filter_keys"].format(
|
lightrag/kg/postgres_impl.py
CHANGED
@@ -237,16 +237,6 @@ class PGKVStorage(BaseKVStorage):
|
|
237 |
params = {"workspace": self.db.workspace, "status": status}
|
238 |
return await self.db.query(SQL, params, multirows=True)
|
239 |
|
240 |
-
async def all_keys(self) -> list[dict]:
|
241 |
-
if is_namespace(self.namespace, NameSpace.KV_STORE_LLM_RESPONSE_CACHE):
|
242 |
-
sql = "select workspace,mode,id from lightrag_llm_cache"
|
243 |
-
res = await self.db.query(sql, multirows=True)
|
244 |
-
return res
|
245 |
-
else:
|
246 |
-
logger.error(
|
247 |
-
f"all_keys is only implemented for llm_response_cache, not for {self.namespace}"
|
248 |
-
)
|
249 |
-
|
250 |
async def filter_keys(self, keys: List[str]) -> Set[str]:
|
251 |
"""Filter out duplicated content"""
|
252 |
sql = SQL_TEMPLATES["filter_keys"].format(
|
|
|
237 |
params = {"workspace": self.db.workspace, "status": status}
|
238 |
return await self.db.query(SQL, params, multirows=True)
|
239 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
240 |
async def filter_keys(self, keys: List[str]) -> Set[str]:
|
241 |
"""Filter out duplicated content"""
|
242 |
sql = SQL_TEMPLATES["filter_keys"].format(
|
lightrag/kg/redis_impl.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import os
|
2 |
-
from typing import Any
|
3 |
from tqdm.asyncio import tqdm as tqdm_async
|
4 |
from dataclasses import dataclass
|
5 |
import pipmaster as pm
|
@@ -20,11 +20,7 @@ class RedisKVStorage(BaseKVStorage):
|
|
20 |
redis_url = os.environ.get("REDIS_URI", "redis://localhost:6379")
|
21 |
self._redis = Redis.from_url(redis_url, decode_responses=True)
|
22 |
logger.info(f"Use Redis as KV {self.namespace}")
|
23 |
-
|
24 |
-
async def all_keys(self) -> list[str]:
|
25 |
-
keys = await self._redis.keys(f"{self.namespace}:*")
|
26 |
-
return [key.split(":", 1)[-1] for key in keys]
|
27 |
-
|
28 |
async def get_by_id(self, id):
|
29 |
data = await self._redis.get(f"{self.namespace}:{id}")
|
30 |
return json.loads(data) if data else None
|
@@ -57,11 +53,4 @@ class RedisKVStorage(BaseKVStorage):
|
|
57 |
async def drop(self) -> None:
|
58 |
keys = await self._redis.keys(f"{self.namespace}:*")
|
59 |
if keys:
|
60 |
-
await self._redis.delete(*keys)
|
61 |
-
|
62 |
-
async def get_by_status(self, status: str) -> Union[list[dict[str, Any]], None]:
|
63 |
-
pipe = self._redis.pipeline()
|
64 |
-
for key in await self._redis.keys(f"{self.namespace}:*"):
|
65 |
-
pipe.hgetall(key)
|
66 |
-
results = await pipe.execute()
|
67 |
-
return [data for data in results if data.get("status") == status] or None
|
|
|
1 |
import os
|
2 |
+
from typing import Any
|
3 |
from tqdm.asyncio import tqdm as tqdm_async
|
4 |
from dataclasses import dataclass
|
5 |
import pipmaster as pm
|
|
|
20 |
redis_url = os.environ.get("REDIS_URI", "redis://localhost:6379")
|
21 |
self._redis = Redis.from_url(redis_url, decode_responses=True)
|
22 |
logger.info(f"Use Redis as KV {self.namespace}")
|
23 |
+
|
|
|
|
|
|
|
|
|
24 |
async def get_by_id(self, id):
|
25 |
data = await self._redis.get(f"{self.namespace}:{id}")
|
26 |
return json.loads(data) if data else None
|
|
|
53 |
async def drop(self) -> None:
|
54 |
keys = await self._redis.keys(f"{self.namespace}:*")
|
55 |
if keys:
|
56 |
+
await self._redis.delete(*keys)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
lightrag/lightrag.py
CHANGED
@@ -29,6 +29,7 @@ from .base import (
|
|
29 |
BaseKVStorage,
|
30 |
BaseVectorStorage,
|
31 |
DocStatus,
|
|
|
32 |
QueryParam,
|
33 |
StorageNameSpace,
|
34 |
)
|
@@ -319,7 +320,7 @@ class LightRAG:
|
|
319 |
|
320 |
# Initialize document status storage
|
321 |
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
|
322 |
-
self.doc_status:
|
323 |
namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
|
324 |
global_config=global_config,
|
325 |
embedding_func=None,
|
@@ -394,10 +395,8 @@ class LightRAG:
|
|
394 |
split_by_character_only: if split_by_character_only is True, split the string by character only, when
|
395 |
split_by_character is None, this parameter is ignored.
|
396 |
"""
|
397 |
-
await self.
|
398 |
-
await self.apipeline_process_enqueue_documents(
|
399 |
-
split_by_character, split_by_character_only
|
400 |
-
)
|
401 |
|
402 |
def insert_custom_chunks(self, full_text: str, text_chunks: list[str]):
|
403 |
loop = always_get_an_event_loop()
|
@@ -496,8 +495,13 @@ class LightRAG:
|
|
496 |
|
497 |
# 3. Filter out already processed documents
|
498 |
add_doc_keys: set[str] = set()
|
499 |
-
|
|
|
|
|
|
|
|
|
500 |
add_doc_keys = new_docs.keys() - excluded_ids
|
|
|
501 |
new_docs = {k: v for k, v in new_docs.items() if k in add_doc_keys}
|
502 |
|
503 |
if not new_docs:
|
@@ -513,12 +517,12 @@ class LightRAG:
|
|
513 |
to_process_doc_keys: list[str] = []
|
514 |
|
515 |
# Fetch failed documents
|
516 |
-
failed_docs = await self.doc_status.
|
517 |
if failed_docs:
|
518 |
to_process_doc_keys.extend([doc["id"] for doc in failed_docs])
|
519 |
|
520 |
# Fetch pending documents
|
521 |
-
pending_docs = await self.doc_status.
|
522 |
if pending_docs:
|
523 |
to_process_doc_keys.extend([doc["id"] for doc in pending_docs])
|
524 |
|
|
|
29 |
BaseKVStorage,
|
30 |
BaseVectorStorage,
|
31 |
DocStatus,
|
32 |
+
DocStatusStorage,
|
33 |
QueryParam,
|
34 |
StorageNameSpace,
|
35 |
)
|
|
|
320 |
|
321 |
# Initialize document status storage
|
322 |
self.doc_status_storage_cls = self._get_storage_class(self.doc_status_storage)
|
323 |
+
self.doc_status: DocStatusStorage = self.doc_status_storage_cls(
|
324 |
namespace=make_namespace(self.namespace_prefix, NameSpace.DOC_STATUS),
|
325 |
global_config=global_config,
|
326 |
embedding_func=None,
|
|
|
395 |
split_by_character_only: if split_by_character_only is True, split the string by character only, when
|
396 |
split_by_character is None, this parameter is ignored.
|
397 |
"""
|
398 |
+
await self.apipeline_enqueue_documents(string_or_strings)
|
399 |
+
await self.apipeline_process_enqueue_documents(split_by_character, split_by_character_only)
|
|
|
|
|
400 |
|
401 |
def insert_custom_chunks(self, full_text: str, text_chunks: list[str]):
|
402 |
loop = always_get_an_event_loop()
|
|
|
495 |
|
496 |
# 3. Filter out already processed documents
|
497 |
add_doc_keys: set[str] = set()
|
498 |
+
# Get docs ids
|
499 |
+
in_process_keys = list(new_docs.keys())
|
500 |
+
# Get in progress docs ids
|
501 |
+
excluded_ids = await self.doc_status.get_by_ids(in_process_keys)
|
502 |
+
# Exclude already in process
|
503 |
add_doc_keys = new_docs.keys() - excluded_ids
|
504 |
+
# Filter
|
505 |
new_docs = {k: v for k, v in new_docs.items() if k in add_doc_keys}
|
506 |
|
507 |
if not new_docs:
|
|
|
517 |
to_process_doc_keys: list[str] = []
|
518 |
|
519 |
# Fetch failed documents
|
520 |
+
failed_docs = await self.doc_status.get_failed_docs()
|
521 |
if failed_docs:
|
522 |
to_process_doc_keys.extend([doc["id"] for doc in failed_docs])
|
523 |
|
524 |
# Fetch pending documents
|
525 |
+
pending_docs = await self.doc_status.get_pending_docs()
|
526 |
if pending_docs:
|
527 |
to_process_doc_keys.extend([doc["id"] for doc in pending_docs])
|
528 |
|