Merge pull request #444 from davidleon/fix/lazy_import
Browse files
- lightrag/lightrag.py +13 -6
- lightrag/llm.py +1 -0
- lightrag/storage.py +10 -4
- lightrag/utils.py +20 -1
lightrag/lightrag.py
CHANGED
@@ -48,18 +48,25 @@ from .storage import (
|
|
48 |
|
49 |
|
50 |
def lazy_external_import(module_name: str, class_name: str):
|
51 |
-
"""Lazily import an external module
|
52 |
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
import importlib
|
55 |
|
56 |
# Import the module using importlib
|
57 |
-
module = importlib.import_module(module_name)
|
58 |
|
59 |
-
# Get the class from the module
|
60 |
-
|
|
|
61 |
|
62 |
-
# Return the import_class function itself, not its result
|
63 |
return import_class
|
64 |
|
65 |
|
|
|
48 |
|
49 |
|
50 |
def lazy_external_import(module_name: str, class_name: str):
    """Return a factory that imports ``class_name`` from ``module_name`` on demand.

    Nothing is imported when this function is called. The returned callable
    performs the import each time it is invoked and instantiates the class
    with the arguments it receives.

    The package of the *calling* module is captured here and passed to
    ``importlib.import_module`` as the ``package`` anchor, so a relative
    ``module_name`` (one starting with ``.``) is resolved against the caller.

    Args:
        module_name: Absolute or relative dotted module name.
        class_name: Name of the attribute (normally a class) to fetch from
            the imported module.

    Returns:
        A callable ``import_class(*args, **kwargs)`` that imports the module,
        looks up ``class_name`` on it and returns ``cls(*args, **kwargs)``.
    """
    import inspect

    # inspect.currentframe() returns None on interpreters without Python
    # stack-frame support; fall back to "no package" instead of crashing.
    current_frame = inspect.currentframe()
    caller_frame = current_frame.f_back if current_frame is not None else None
    try:
        caller_module = (
            inspect.getmodule(caller_frame) if caller_frame is not None else None
        )
        package = getattr(caller_module, "__package__", None) if caller_module else None
    finally:
        # Frame objects keep their locals alive; drop the references promptly
        # so no reference cycle is left behind.
        del current_frame, caller_frame

    def import_class(*args, **kwargs):
        import importlib

        # Import the module using importlib, anchored at the caller's package
        module = importlib.import_module(module_name, package=package)

        # Get the class from the module and instantiate it
        cls = getattr(module, class_name)
        return cls(*args, **kwargs)

    return import_class
|
71 |
|
72 |
|
lightrag/llm.py
CHANGED
@@ -64,6 +64,7 @@ async def openai_complete_if_cache(
|
|
64 |
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
|
65 |
)
|
66 |
kwargs.pop("hashing_kv", None)
|
|
|
67 |
messages = []
|
68 |
if system_prompt:
|
69 |
messages.append({"role": "system", "content": system_prompt})
|
|
|
64 |
AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
|
65 |
)
|
66 |
kwargs.pop("hashing_kv", None)
|
67 |
+
kwargs.pop("keyword_extraction", None)
|
68 |
messages = []
|
69 |
if system_prompt:
|
70 |
messages.append({"role": "system", "content": system_prompt})
|
lightrag/storage.py
CHANGED
@@ -107,10 +107,16 @@ class NanoVectorDBStorage(BaseVectorStorage):
|
|
107 |
embeddings = await f
|
108 |
embeddings_list.append(embeddings)
|
109 |
embeddings = np.concatenate(embeddings_list)
|
110 |
-
|
111 |
-
d
|
112 |
-
|
113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
|
115 |
async def query(self, query: str, top_k=5):
|
116 |
embedding = await self.embedding_func([query])
|
|
|
107 |
embeddings = await f
|
108 |
embeddings_list.append(embeddings)
|
109 |
embeddings = np.concatenate(embeddings_list)
|
110 |
+
if len(embeddings) == len(list_data):
|
111 |
+
for i, d in enumerate(list_data):
|
112 |
+
d["__vector__"] = embeddings[i]
|
113 |
+
results = self._client.upsert(datas=list_data)
|
114 |
+
return results
|
115 |
+
else:
|
116 |
+
# sometimes the embedding is not returned correctly. just log it.
|
117 |
+
logger.error(
|
118 |
+
f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
|
119 |
+
)
|
120 |
|
121 |
async def query(self, query: str, top_k=5):
|
122 |
embedding = await self.embedding_func([query])
|
lightrag/utils.py
CHANGED
@@ -17,6 +17,17 @@ import tiktoken
|
|
17 |
|
18 |
from lightrag.prompt import PROMPTS
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
ENCODER = None
|
21 |
|
22 |
logger = logging.getLogger("lightrag")
|
@@ -42,9 +53,17 @@ class EmbeddingFunc:
|
|
42 |
embedding_dim: int
|
43 |
max_token_size: int
|
44 |
func: callable
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
async def __call__(self, *args, **kwargs) -> np.ndarray:
|
47 |
-
|
|
|
48 |
|
49 |
|
50 |
def locate_json_string_body_from_string(content: str) -> Union[str, None]:
|
|
|
17 |
|
18 |
from lightrag.prompt import PROMPTS
|
19 |
|
20 |
+
|
21 |
+
class UnlimitedSemaphore:
    """No-op async context manager: entering never waits and exiting releases nothing."""

    async def __aenter__(self):
        # Nothing to acquire; the ``async with ... as`` target is bound to None.
        return None

    async def __aexit__(self, exc_type, exc, tb):
        # A falsy return lets any exception raised in the body propagate.
        return None
|
29 |
+
|
30 |
+
|
31 |
ENCODER = None
|
32 |
|
33 |
logger = logging.getLogger("lightrag")
|
|
|
53 |
embedding_dim: int
|
54 |
max_token_size: int
|
55 |
func: callable
|
56 |
+
concurrent_limit: int = 16
|
57 |
+
|
58 |
+
def __post_init__(self):
|
59 |
+
if self.concurrent_limit != 0:
|
60 |
+
self._semaphore = asyncio.Semaphore(self.concurrent_limit)
|
61 |
+
else:
|
62 |
+
self._semaphore = UnlimitedSemaphore()
|
63 |
|
64 |
async def __call__(self, *args, **kwargs) -> np.ndarray:
|
65 |
+
async with self._semaphore:
|
66 |
+
return await self.func(*args, **kwargs)
|
67 |
|
68 |
|
69 |
def locate_json_string_body_from_string(content: str) -> Union[str, None]:
|