zrguo committed on
Commit
492af66
·
unverified ·
2 Parent(s): f4a8f96 76b6f93

Merge pull request #444 from davidleon/fix/lazy_import

Browse files
lightrag/lightrag.py CHANGED
@@ -48,18 +48,25 @@ from .storage import (
48
 
49
 
50
  def lazy_external_import(module_name: str, class_name: str):
51
- """Lazily import an external module and return a class from it."""
52
 
53
- def import_class():
 
 
 
 
 
 
 
54
  import importlib
55
 
56
  # Import the module using importlib
57
- module = importlib.import_module(module_name)
58
 
59
- # Get the class from the module
60
- return getattr(module, class_name)
 
61
 
62
- # Return the import_class function itself, not its result
63
  return import_class
64
 
65
 
 
48
 
49
 
50
  def lazy_external_import(module_name: str, class_name: str):
51
+ """Lazily import a class from an external module based on the package of the caller."""
52
 
53
+ # Get the caller's module and package
54
+ import inspect
55
+
56
+ caller_frame = inspect.currentframe().f_back
57
+ module = inspect.getmodule(caller_frame)
58
+ package = module.__package__ if module else None
59
+
60
+ def import_class(*args, **kwargs):
61
  import importlib
62
 
63
  # Import the module using importlib
64
+ module = importlib.import_module(module_name, package=package)
65
 
66
+ # Get the class from the module and instantiate it
67
+ cls = getattr(module, class_name)
68
+ return cls(*args, **kwargs)
69
 
 
70
  return import_class
71
 
72
 
lightrag/llm.py CHANGED
@@ -64,6 +64,7 @@ async def openai_complete_if_cache(
64
  AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
65
  )
66
  kwargs.pop("hashing_kv", None)
 
67
  messages = []
68
  if system_prompt:
69
  messages.append({"role": "system", "content": system_prompt})
 
64
  AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
65
  )
66
  kwargs.pop("hashing_kv", None)
67
+ kwargs.pop("keyword_extraction", None)
68
  messages = []
69
  if system_prompt:
70
  messages.append({"role": "system", "content": system_prompt})
lightrag/storage.py CHANGED
@@ -107,10 +107,16 @@ class NanoVectorDBStorage(BaseVectorStorage):
107
  embeddings = await f
108
  embeddings_list.append(embeddings)
109
  embeddings = np.concatenate(embeddings_list)
110
- for i, d in enumerate(list_data):
111
- d["__vector__"] = embeddings[i]
112
- results = self._client.upsert(datas=list_data)
113
- return results
 
 
 
 
 
 
114
 
115
  async def query(self, query: str, top_k=5):
116
  embedding = await self.embedding_func([query])
 
107
  embeddings = await f
108
  embeddings_list.append(embeddings)
109
  embeddings = np.concatenate(embeddings_list)
110
+ if len(embeddings) == len(list_data):
111
+ for i, d in enumerate(list_data):
112
+ d["__vector__"] = embeddings[i]
113
+ results = self._client.upsert(datas=list_data)
114
+ return results
115
+ else:
116
+ # sometimes the embedding is not returned correctly. just log it.
117
+ logger.error(
118
+ f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
119
+ )
120
 
121
  async def query(self, query: str, top_k=5):
122
  embedding = await self.embedding_func([query])
lightrag/utils.py CHANGED
@@ -17,6 +17,17 @@ import tiktoken
17
 
18
  from lightrag.prompt import PROMPTS
19
 
 
 
 
 
 
 
 
 
 
 
 
20
  ENCODER = None
21
 
22
  logger = logging.getLogger("lightrag")
@@ -42,9 +53,17 @@ class EmbeddingFunc:
42
  embedding_dim: int
43
  max_token_size: int
44
  func: callable
 
 
 
 
 
 
 
45
 
46
  async def __call__(self, *args, **kwargs) -> np.ndarray:
47
- return await self.func(*args, **kwargs)
 
48
 
49
 
50
  def locate_json_string_body_from_string(content: str) -> Union[str, None]:
 
17
 
18
  from lightrag.prompt import PROMPTS
19
 
20
+
21
+ class UnlimitedSemaphore:
22
+ """A context manager that allows unlimited access."""
23
+
24
+ async def __aenter__(self):
25
+ pass
26
+
27
+ async def __aexit__(self, exc_type, exc, tb):
28
+ pass
29
+
30
+
31
  ENCODER = None
32
 
33
  logger = logging.getLogger("lightrag")
 
53
  embedding_dim: int
54
  max_token_size: int
55
  func: callable
56
+ concurrent_limit: int = 16
57
+
58
+ def __post_init__(self):
59
+ if self.concurrent_limit != 0:
60
+ self._semaphore = asyncio.Semaphore(self.concurrent_limit)
61
+ else:
62
+ self._semaphore = UnlimitedSemaphore()
63
 
64
  async def __call__(self, *args, **kwargs) -> np.ndarray:
65
+ async with self._semaphore:
66
+ return await self.func(*args, **kwargs)
67
 
68
 
69
  def locate_json_string_body_from_string(content: str) -> Union[str, None]: