zrguo committed on
Commit
f2caa47
·
unverified ·
2 Parent(s): af14936 154f331

Merge pull request #1051 from HKUDS/dev

Browse files

Refactor LightRAG for better code organization

Files changed (3) hide show
  1. lightrag/lightrag.py +39 -85
  2. lightrag/operate.py +87 -0
  3. lightrag/utils.py +49 -0
lightrag/lightrag.py CHANGED
@@ -30,11 +30,10 @@ from .namespace import NameSpace, make_namespace
30
  from .operate import (
31
  chunking_by_token_size,
32
  extract_entities,
33
- extract_keywords_only,
34
  kg_query,
35
- kg_query_with_keywords,
36
  mix_kg_vector_query,
37
  naive_query,
 
38
  )
39
  from .prompt import GRAPH_FIELD_SEP, PROMPTS
40
  from .utils import (
@@ -45,6 +44,9 @@ from .utils import (
45
  encode_string_by_tiktoken,
46
  lazy_external_import,
47
  limit_async_func_call,
 
 
 
48
  logger,
49
  )
50
  from .types import KnowledgeGraph
@@ -309,7 +311,7 @@ class LightRAG:
309
  # Verify storage implementation compatibility
310
  verify_storage_implementation(storage_type, storage_name)
311
  # Check environment variables
312
- # self.check_storage_env_vars(storage_name)
313
 
314
  # Ensure vector_db_storage_cls_kwargs has required fields
315
  self.vector_db_storage_cls_kwargs = {
@@ -536,11 +538,6 @@ class LightRAG:
536
  storage_class = lazy_external_import(import_path, storage_name)
537
  return storage_class
538
 
539
- @staticmethod
540
- def clean_text(text: str) -> str:
541
- """Clean text by removing null bytes (0x00) and whitespace"""
542
- return text.strip().replace("\x00", "")
543
-
544
  def insert(
545
  self,
546
  input: str | list[str],
@@ -602,8 +599,8 @@ class LightRAG:
602
  update_storage = False
603
  try:
604
  # Clean input texts
605
- full_text = self.clean_text(full_text)
606
- text_chunks = [self.clean_text(chunk) for chunk in text_chunks]
607
 
608
  # Process cleaned texts
609
  if doc_id is None:
@@ -682,7 +679,7 @@ class LightRAG:
682
  contents = {id_: doc for id_, doc in zip(ids, input)}
683
  else:
684
  # Clean input text and remove duplicates
685
- input = list(set(self.clean_text(doc) for doc in input))
686
  # Generate contents dict of MD5 hash IDs and documents
687
  contents = {compute_mdhash_id(doc, prefix="doc-"): doc for doc in input}
688
 
@@ -698,7 +695,7 @@ class LightRAG:
698
  new_docs: dict[str, Any] = {
699
  id_: {
700
  "content": content,
701
- "content_summary": self._get_content_summary(content),
702
  "content_length": len(content),
703
  "status": DocStatus.PENDING,
704
  "created_at": datetime.now().isoformat(),
@@ -1063,7 +1060,7 @@ class LightRAG:
1063
  all_chunks_data: dict[str, dict[str, str]] = {}
1064
  chunk_to_source_map: dict[str, str] = {}
1065
  for chunk_data in custom_kg.get("chunks", []):
1066
- chunk_content = self.clean_text(chunk_data["content"])
1067
  source_id = chunk_data["source_id"]
1068
  tokens = len(
1069
  encode_string_by_tiktoken(
@@ -1296,8 +1293,17 @@ class LightRAG:
1296
  self, query: str, prompt: str, param: QueryParam = QueryParam()
1297
  ):
1298
  """
1299
- 1. Extract keywords from the 'query' using new function in operate.py.
1300
- 2. Then run the standard aquery() flow with the final prompt (formatted_question).
 
 
 
 
 
 
 
 
 
1301
  """
1302
  loop = always_get_an_event_loop()
1303
  return loop.run_until_complete(
@@ -1308,66 +1314,29 @@ class LightRAG:
1308
  self, query: str, prompt: str, param: QueryParam = QueryParam()
1309
  ) -> str | AsyncIterator[str]:
1310
  """
1311
- 1. Calls extract_keywords_only to get HL/LL keywords from 'query'.
1312
- 2. Then calls kg_query(...) or naive_query(...), etc. as the main query, while also injecting the newly extracted keywords if needed.
 
 
 
 
 
 
 
1313
  """
1314
- # ---------------------
1315
- # STEP 1: Keyword Extraction
1316
- # ---------------------
1317
- hl_keywords, ll_keywords = await extract_keywords_only(
1318
- text=query,
1319
  param=param,
 
 
 
 
 
1320
  global_config=asdict(self),
1321
- hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
1322
  )
1323
 
1324
- param.hl_keywords = hl_keywords
1325
- param.ll_keywords = ll_keywords
1326
-
1327
- # ---------------------
1328
- # STEP 2: Final Query Logic
1329
- # ---------------------
1330
-
1331
- # Create a new string with the prompt and the keywords
1332
- ll_keywords_str = ", ".join(ll_keywords)
1333
- hl_keywords_str = ", ".join(hl_keywords)
1334
- formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}"
1335
-
1336
- if param.mode in ["local", "global", "hybrid"]:
1337
- response = await kg_query_with_keywords(
1338
- formatted_question,
1339
- self.chunk_entity_relation_graph,
1340
- self.entities_vdb,
1341
- self.relationships_vdb,
1342
- self.text_chunks,
1343
- param,
1344
- asdict(self),
1345
- hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
1346
- )
1347
- elif param.mode == "naive":
1348
- response = await naive_query(
1349
- formatted_question,
1350
- self.chunks_vdb,
1351
- self.text_chunks,
1352
- param,
1353
- asdict(self),
1354
- hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
1355
- )
1356
- elif param.mode == "mix":
1357
- response = await mix_kg_vector_query(
1358
- formatted_question,
1359
- self.chunk_entity_relation_graph,
1360
- self.entities_vdb,
1361
- self.relationships_vdb,
1362
- self.chunks_vdb,
1363
- self.text_chunks,
1364
- param,
1365
- asdict(self),
1366
- hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
1367
- )
1368
- else:
1369
- raise ValueError(f"Unknown mode {param.mode}")
1370
-
1371
  await self._query_done()
1372
  return response
1373
 
@@ -1465,21 +1434,6 @@ class LightRAG:
1465
  ]
1466
  )
1467
 
1468
- def _get_content_summary(self, content: str, max_length: int = 100) -> str:
1469
- """Get summary of document content
1470
-
1471
- Args:
1472
- content: Original document content
1473
- max_length: Maximum length of summary
1474
-
1475
- Returns:
1476
- Truncated content with ellipsis if needed
1477
- """
1478
- content = content.strip()
1479
- if len(content) <= max_length:
1480
- return content
1481
- return content[:max_length] + "..."
1482
-
1483
  async def get_processing_status(self) -> dict[str, int]:
1484
  """Get current document processing status counts
1485
 
 
30
  from .operate import (
31
  chunking_by_token_size,
32
  extract_entities,
 
33
  kg_query,
 
34
  mix_kg_vector_query,
35
  naive_query,
36
+ query_with_keywords,
37
  )
38
  from .prompt import GRAPH_FIELD_SEP, PROMPTS
39
  from .utils import (
 
44
  encode_string_by_tiktoken,
45
  lazy_external_import,
46
  limit_async_func_call,
47
+ get_content_summary,
48
+ clean_text,
49
+ check_storage_env_vars,
50
  logger,
51
  )
52
  from .types import KnowledgeGraph
 
311
  # Verify storage implementation compatibility
312
  verify_storage_implementation(storage_type, storage_name)
313
  # Check environment variables
314
+ check_storage_env_vars(storage_name)
315
 
316
  # Ensure vector_db_storage_cls_kwargs has required fields
317
  self.vector_db_storage_cls_kwargs = {
 
538
  storage_class = lazy_external_import(import_path, storage_name)
539
  return storage_class
540
 
 
 
 
 
 
541
  def insert(
542
  self,
543
  input: str | list[str],
 
599
  update_storage = False
600
  try:
601
  # Clean input texts
602
+ full_text = clean_text(full_text)
603
+ text_chunks = [clean_text(chunk) for chunk in text_chunks]
604
 
605
  # Process cleaned texts
606
  if doc_id is None:
 
679
  contents = {id_: doc for id_, doc in zip(ids, input)}
680
  else:
681
  # Clean input text and remove duplicates
682
+ input = list(set(clean_text(doc) for doc in input))
683
  # Generate contents dict of MD5 hash IDs and documents
684
  contents = {compute_mdhash_id(doc, prefix="doc-"): doc for doc in input}
685
 
 
695
  new_docs: dict[str, Any] = {
696
  id_: {
697
  "content": content,
698
+ "content_summary": get_content_summary(content),
699
  "content_length": len(content),
700
  "status": DocStatus.PENDING,
701
  "created_at": datetime.now().isoformat(),
 
1060
  all_chunks_data: dict[str, dict[str, str]] = {}
1061
  chunk_to_source_map: dict[str, str] = {}
1062
  for chunk_data in custom_kg.get("chunks", []):
1063
+ chunk_content = clean_text(chunk_data["content"])
1064
  source_id = chunk_data["source_id"]
1065
  tokens = len(
1066
  encode_string_by_tiktoken(
 
1293
  self, query: str, prompt: str, param: QueryParam = QueryParam()
1294
  ):
1295
  """
1296
+ Query with separate keyword extraction step.
1297
+
1298
+ This method extracts keywords from the query first, then uses them for the query.
1299
+
1300
+ Args:
1301
+ query: User query
1302
+ prompt: Additional prompt for the query
1303
+ param: Query parameters
1304
+
1305
+ Returns:
1306
+ Query response
1307
  """
1308
  loop = always_get_an_event_loop()
1309
  return loop.run_until_complete(
 
1314
  self, query: str, prompt: str, param: QueryParam = QueryParam()
1315
  ) -> str | AsyncIterator[str]:
1316
  """
1317
+ Async version of query_with_separate_keyword_extraction.
1318
+
1319
+ Args:
1320
+ query: User query
1321
+ prompt: Additional prompt for the query
1322
+ param: Query parameters
1323
+
1324
+ Returns:
1325
+ Query response or async iterator
1326
  """
1327
+ response = await query_with_keywords(
1328
+ query=query,
1329
+ prompt=prompt,
 
 
1330
  param=param,
1331
+ knowledge_graph_inst=self.chunk_entity_relation_graph,
1332
+ entities_vdb=self.entities_vdb,
1333
+ relationships_vdb=self.relationships_vdb,
1334
+ chunks_vdb=self.chunks_vdb,
1335
+ text_chunks_db=self.text_chunks,
1336
  global_config=asdict(self),
1337
+ hashing_kv=self.llm_response_cache,
1338
  )
1339
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1340
  await self._query_done()
1341
  return response
1342
 
 
1434
  ]
1435
  )
1436
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1437
  async def get_processing_status(self) -> dict[str, int]:
1438
  """Get current document processing status counts
1439
 
lightrag/operate.py CHANGED
@@ -1916,3 +1916,90 @@ async def kg_query_with_keywords(
1916
  )
1917
 
1918
  return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1916
  )
1917
 
1918
  return response
1919
+
1920
+
1921
+ async def query_with_keywords(
1922
+ query: str,
1923
+ prompt: str,
1924
+ param: QueryParam,
1925
+ knowledge_graph_inst: BaseGraphStorage,
1926
+ entities_vdb: BaseVectorStorage,
1927
+ relationships_vdb: BaseVectorStorage,
1928
+ chunks_vdb: BaseVectorStorage,
1929
+ text_chunks_db: BaseKVStorage,
1930
+ global_config: dict[str, str],
1931
+ hashing_kv: BaseKVStorage | None = None,
1932
+ ) -> str | AsyncIterator[str]:
1933
+ """
1934
+ Extract keywords from the query and then use them for retrieving information.
1935
+
1936
+ 1. Extracts high-level and low-level keywords from the query
1937
+ 2. Formats the query with the extracted keywords and prompt
1938
+ 3. Uses the appropriate query method based on param.mode
1939
+
1940
+ Args:
1941
+ query: The user's query
1942
+ prompt: Additional prompt to prepend to the query
1943
+ param: Query parameters
1944
+ knowledge_graph_inst: Knowledge graph storage
1945
+ entities_vdb: Entities vector database
1946
+ relationships_vdb: Relationships vector database
1947
+ chunks_vdb: Document chunks vector database
1948
+ text_chunks_db: Text chunks storage
1949
+ global_config: Global configuration
1950
+ hashing_kv: Cache storage
1951
+
1952
+ Returns:
1953
+ Query response or async iterator
1954
+ """
1955
+ # Extract keywords
1956
+ hl_keywords, ll_keywords = await extract_keywords_only(
1957
+ text=query,
1958
+ param=param,
1959
+ global_config=global_config,
1960
+ hashing_kv=hashing_kv,
1961
+ )
1962
+
1963
+ param.hl_keywords = hl_keywords
1964
+ param.ll_keywords = ll_keywords
1965
+
1966
+ # Create a new string with the prompt and the keywords
1967
+ ll_keywords_str = ", ".join(ll_keywords)
1968
+ hl_keywords_str = ", ".join(hl_keywords)
1969
+ formatted_question = f"{prompt}\n\n### Keywords:\nHigh-level: {hl_keywords_str}\nLow-level: {ll_keywords_str}\n\n### Query:\n{query}"
1970
+
1971
+ # Use appropriate query method based on mode
1972
+ if param.mode in ["local", "global", "hybrid"]:
1973
+ return await kg_query_with_keywords(
1974
+ formatted_question,
1975
+ knowledge_graph_inst,
1976
+ entities_vdb,
1977
+ relationships_vdb,
1978
+ text_chunks_db,
1979
+ param,
1980
+ global_config,
1981
+ hashing_kv=hashing_kv,
1982
+ )
1983
+ elif param.mode == "naive":
1984
+ return await naive_query(
1985
+ formatted_question,
1986
+ chunks_vdb,
1987
+ text_chunks_db,
1988
+ param,
1989
+ global_config,
1990
+ hashing_kv=hashing_kv,
1991
+ )
1992
+ elif param.mode == "mix":
1993
+ return await mix_kg_vector_query(
1994
+ formatted_question,
1995
+ knowledge_graph_inst,
1996
+ entities_vdb,
1997
+ relationships_vdb,
1998
+ chunks_vdb,
1999
+ text_chunks_db,
2000
+ param,
2001
+ global_config,
2002
+ hashing_kv=hashing_kv,
2003
+ )
2004
+ else:
2005
+ raise ValueError(f"Unknown mode {param.mode}")
lightrag/utils.py CHANGED
@@ -890,3 +890,52 @@ def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any
890
  return cls(*args, **kwargs)
891
 
892
  return import_class
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
890
  return cls(*args, **kwargs)
891
 
892
  return import_class
893
+
894
+
895
+ def get_content_summary(content: str, max_length: int = 100) -> str:
896
+ """Get summary of document content
897
+
898
+ Args:
899
+ content: Original document content
900
+ max_length: Maximum length of summary
901
+
902
+ Returns:
903
+ Truncated content with ellipsis if needed
904
+ """
905
+ content = content.strip()
906
+ if len(content) <= max_length:
907
+ return content
908
+ return content[:max_length] + "..."
909
+
910
+
911
+ def clean_text(text: str) -> str:
912
+ """Clean text by removing null bytes (0x00) and whitespace
913
+
914
+ Args:
915
+ text: Input text to clean
916
+
917
+ Returns:
918
+ Cleaned text
919
+ """
920
+ return text.strip().replace("\x00", "")
921
+
922
+
923
+ def check_storage_env_vars(storage_name: str) -> None:
924
+ """Check if all required environment variables for storage implementation exist
925
+
926
+ Args:
927
+ storage_name: Storage implementation name
928
+
929
+ Raises:
930
+ ValueError: If required environment variables are missing
931
+ """
932
+ from lightrag.kg import STORAGE_ENV_REQUIREMENTS
933
+
934
+ required_vars = STORAGE_ENV_REQUIREMENTS.get(storage_name, [])
935
+ missing_vars = [var for var in required_vars if var not in os.environ]
936
+
937
+ if missing_vars:
938
+ raise ValueError(
939
+ f"Storage implementation '{storage_name}' requires the following "
940
+ f"environment variables: {', '.join(missing_vars)}"
941
+ )