Daniel.y committed
Commit 23ed632 · unverified · 2 Parent(s): ac06884 826c791

Merge pull request #1328 from danielaskdd/main

Fix LLM cache so that it works for node and edge merging
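At a high level, this merge moves the cache-aware LLM call out of the _user_llm_func_with_cache closure in extract_entities and into a shared use_llm_func_with_cache helper in lightrag/utils.py, then threads llm_response_cache (together with pipeline_status and its lock) through _merge_nodes_then_upsert, _merge_edges_then_upsert and _handle_entity_relation_summary, so the LLM calls made while re-summarizing merged nodes and edges also go through the cache. A minimal sketch of the call shape this introduces, with names taken from the diff below:

    summary = await use_llm_func_with_cache(
        use_prompt,                             # prompt built from the summarize template
        use_llm_func,                           # global_config["llm_model_func"]
        llm_response_cache=llm_response_cache,  # shared KV storage backing the cache
        max_tokens=summary_max_tokens,
        cache_type="extract",
    )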

lightrag/api/__init__.py CHANGED
@@ -1 +1 @@
-__api_version__ = "0142"
+__api_version__ = "0143"
lightrag/kg/json_doc_status_impl.py CHANGED
@@ -116,7 +116,7 @@ class JsonDocStatusStorage(DocStatusStorage):
         """
         if not data:
             return
-        logger.info(f"Inserting {len(data)} records to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} records to {self.namespace}")
         async with self._storage_lock:
             self._data.update(data)
             await set_all_update_flags(self.namespace)
lightrag/kg/json_kv_impl.py CHANGED
@@ -121,7 +121,7 @@ class JsonKVStorage(BaseKVStorage):
         """
         if not data:
             return
-        logger.info(f"Inserting {len(data)} records to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} records to {self.namespace}")
         async with self._storage_lock:
             self._data.update(data)
             await set_all_update_flags(self.namespace)
lightrag/kg/nano_vector_db_impl.py CHANGED
@@ -85,7 +85,7 @@ class NanoVectorDBStorage(BaseVectorStorage):
         KG-storage-log should be used to avoid data corruption
         """

-        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return

lightrag/kg/networkx_impl.py CHANGED
@@ -392,7 +392,7 @@ class NetworkXStorage(BaseGraphStorage):
             # Check if storage was updated by another process
             if self.storage_updated.value:
                 # Storage was updated by another process, reload data instead of saving
-                logger.warning(
+                logger.info(
                     f"Graph for {self.namespace} was updated by another process, reloading..."
                 )
                 self._graph = (
lightrag/kg/postgres_impl.py CHANGED
@@ -361,7 +361,7 @@ class PGKVStorage(BaseKVStorage):

     ################ INSERT METHODS ################
     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return

@@ -560,7 +560,7 @@ class PGVectorStorage(BaseVectorStorage):
         return upsert_sql, data

     async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return

@@ -949,7 +949,7 @@ class PGDocStatusStorage(DocStatusStorage):
         Args:
             data: dictionary of document IDs and their status data
         """
-        logger.info(f"Inserting {len(data)} to {self.namespace}")
+        logger.debug(f"Inserting {len(data)} to {self.namespace}")
         if not data:
             return

lightrag/operate.py CHANGED
@@ -24,8 +24,8 @@ from .utils import (
     handle_cache,
     save_to_cache,
     CacheData,
-    statistic_data,
     get_conversation_turns,
+    use_llm_func_with_cache,
 )
 from .base import (
     BaseGraphStorage,
@@ -106,6 +106,9 @@ async def _handle_entity_relation_summary(
     entity_or_relation_name: str,
     description: str,
     global_config: dict,
+    pipeline_status: dict = None,
+    pipeline_status_lock=None,
+    llm_response_cache: BaseKVStorage | None = None,
 ) -> str:
     """Handle entity relation summary
     For each entity or relation, input is the combined description of already existing description and new description.
@@ -122,6 +125,7 @@ async def _handle_entity_relation_summary(
     tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name)
     if len(tokens) < summary_max_tokens:  # No need for summary
         return description
+
     prompt_template = PROMPTS["summarize_entity_descriptions"]
     use_description = decode_tokens_by_tiktoken(
         tokens[:llm_max_tokens], model_name=tiktoken_model_name
@@ -133,7 +137,23 @@ async def _handle_entity_relation_summary(
     )
     use_prompt = prompt_template.format(**context_base)
     logger.debug(f"Trigger summary: {entity_or_relation_name}")
-    summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens)
+
+    # Update pipeline status when LLM summary is needed
+    status_message = "Use LLM to re-summary description..."
+    logger.info(status_message)
+    if pipeline_status is not None and pipeline_status_lock is not None:
+        async with pipeline_status_lock:
+            pipeline_status["latest_message"] = status_message
+            pipeline_status["history_messages"].append(status_message)
+
+    # Use LLM function with cache
+    summary = await use_llm_func_with_cache(
+        use_prompt,
+        use_llm_func,
+        llm_response_cache=llm_response_cache,
+        max_tokens=summary_max_tokens,
+        cache_type="extract",
+    )
     return summary


@@ -212,6 +232,9 @@ async def _merge_nodes_then_upsert(
     nodes_data: list[dict],
     knowledge_graph_inst: BaseGraphStorage,
     global_config: dict,
+    pipeline_status: dict = None,
+    pipeline_status_lock=None,
+    llm_response_cache: BaseKVStorage | None = None,
 ):
     """Get existing nodes from knowledge graph use name,if exists, merge data, else create, then upsert."""
     already_entity_types = []
@@ -221,6 +244,14 @@ async def _merge_nodes_then_upsert(

     already_node = await knowledge_graph_inst.get_node(entity_name)
     if already_node is not None:
+        # Update pipeline status when a node that needs merging is found
+        status_message = f"Merging entity: {entity_name}"
+        logger.info(status_message)
+        if pipeline_status is not None and pipeline_status_lock is not None:
+            async with pipeline_status_lock:
+                pipeline_status["latest_message"] = status_message
+                pipeline_status["history_messages"].append(status_message)
+
         already_entity_types.append(already_node["entity_type"])
         already_source_ids.extend(
             split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP])
@@ -249,7 +280,12 @@ async def _merge_nodes_then_upsert(

     logger.debug(f"file_path: {file_path}")
     description = await _handle_entity_relation_summary(
-        entity_name, description, global_config
+        entity_name,
+        description,
+        global_config,
+        pipeline_status,
+        pipeline_status_lock,
+        llm_response_cache,
     )
     node_data = dict(
         entity_id=entity_name,
@@ -272,6 +308,9 @@ async def _merge_edges_then_upsert(
     edges_data: list[dict],
     knowledge_graph_inst: BaseGraphStorage,
     global_config: dict,
+    pipeline_status: dict = None,
+    pipeline_status_lock=None,
+    llm_response_cache: BaseKVStorage | None = None,
 ):
     already_weights = []
     already_source_ids = []
@@ -280,6 +319,14 @@ async def _merge_edges_then_upsert(
     already_file_paths = []

     if await knowledge_graph_inst.has_edge(src_id, tgt_id):
+        # Update pipeline status when an edge that needs merging is found
+        status_message = f"Merging edge::: {src_id} - {tgt_id}"
+        logger.info(status_message)
+        if pipeline_status is not None and pipeline_status_lock is not None:
+            async with pipeline_status_lock:
+                pipeline_status["latest_message"] = status_message
+                pipeline_status["history_messages"].append(status_message)
+
         already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
         # Handle the case where get_edge returns None or missing fields
         if already_edge:
@@ -358,7 +405,12 @@ async def _merge_edges_then_upsert(
         },
     )
     description = await _handle_entity_relation_summary(
-        f"({src_id}, {tgt_id})", description, global_config
+        f"({src_id}, {tgt_id})",
+        description,
+        global_config,
+        pipeline_status,
+        pipeline_status_lock,
+        llm_response_cache,
     )
     await knowledge_graph_inst.upsert_edge(
         src_id,
@@ -396,9 +448,6 @@ async def extract_entities(
 ) -> None:
     use_llm_func: callable = global_config["llm_model_func"]
     entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
-    enable_llm_cache_for_entity_extract: bool = global_config[
-        "enable_llm_cache_for_entity_extract"
-    ]

     ordered_chunks = list(chunks.items())
     # add language and example number params to prompt
@@ -449,51 +498,7 @@ async def extract_entities(

     graph_db_lock = get_graph_db_lock(enable_logging=False)

-    async def _user_llm_func_with_cache(
-        input_text: str, history_messages: list[dict[str, str]] = None
-    ) -> str:
-        if enable_llm_cache_for_entity_extract and llm_response_cache:
-            if history_messages:
-                history = json.dumps(history_messages, ensure_ascii=False)
-                _prompt = history + "\n" + input_text
-            else:
-                _prompt = input_text
-
-            # TODO: add cache_type="extract"
-            arg_hash = compute_args_hash(_prompt)
-            cached_return, _1, _2, _3 = await handle_cache(
-                llm_response_cache,
-                arg_hash,
-                _prompt,
-                "default",
-                cache_type="extract",
-            )
-            if cached_return:
-                logger.debug(f"Found cache for {arg_hash}")
-                statistic_data["llm_cache"] += 1
-                return cached_return
-            statistic_data["llm_call"] += 1
-            if history_messages:
-                res: str = await use_llm_func(
-                    input_text, history_messages=history_messages
-                )
-            else:
-                res: str = await use_llm_func(input_text)
-            await save_to_cache(
-                llm_response_cache,
-                CacheData(
-                    args_hash=arg_hash,
-                    content=res,
-                    prompt=_prompt,
-                    cache_type="extract",
-                ),
-            )
-            return res
-
-        if history_messages:
-            return await use_llm_func(input_text, history_messages=history_messages)
-        else:
-            return await use_llm_func(input_text)
+    # Use the global use_llm_func_with_cache function from utils.py

     async def _process_extraction_result(
         result: str, chunk_key: str, file_path: str = "unknown_source"
@@ -558,7 +563,12 @@ async def extract_entities(
             **context_base, input_text="{input_text}"
         ).format(**context_base, input_text=content)

-        final_result = await _user_llm_func_with_cache(hint_prompt)
+        final_result = await use_llm_func_with_cache(
+            hint_prompt,
+            use_llm_func,
+            llm_response_cache=llm_response_cache,
+            cache_type="extract",
+        )
         history = pack_user_ass_to_openai_messages(hint_prompt, final_result)

         # Process initial extraction with file path
@@ -568,8 +578,12 @@ async def extract_entities(

         # Process additional gleaning results
         for now_glean_index in range(entity_extract_max_gleaning):
-            glean_result = await _user_llm_func_with_cache(
-                continue_prompt, history_messages=history
+            glean_result = await use_llm_func_with_cache(
+                continue_prompt,
+                use_llm_func,
+                llm_response_cache=llm_response_cache,
+                history_messages=history,
+                cache_type="extract",
             )

             history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
@@ -588,8 +602,12 @@ async def extract_entities(
             if now_glean_index == entity_extract_max_gleaning - 1:
                 break

-            if_loop_result: str = await _user_llm_func_with_cache(
-                if_loop_prompt, history_messages=history
+            if_loop_result: str = await use_llm_func_with_cache(
+                if_loop_prompt,
+                use_llm_func,
+                llm_response_cache=llm_response_cache,
+                history_messages=history,
+                cache_type="extract",
             )
             if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
             if if_loop_result != "yes":
@@ -613,7 +631,13 @@ async def extract_entities(
             # Process and update entities
             for entity_name, entities in maybe_nodes.items():
                 entity_data = await _merge_nodes_then_upsert(
-                    entity_name, entities, knowledge_graph_inst, global_config
+                    entity_name,
+                    entities,
+                    knowledge_graph_inst,
+                    global_config,
+                    pipeline_status,
+                    pipeline_status_lock,
+                    llm_response_cache,
                 )
                 chunk_entities_data.append(entity_data)

@@ -627,6 +651,9 @@ async def extract_entities(
                     edges,
                     knowledge_graph_inst,
                     global_config,
+                    pipeline_status,
+                    pipeline_status_lock,
+                    llm_response_cache,
                 )
                 chunk_relationships_data.append(edge_data)

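For orientation only, a hedged caller-side sketch of how the three new trailing parameters of the merge helpers are meant to be supplied; _merge_nodes_then_upsert is internal to lightrag.operate, and the asyncio.Lock and literal values here are illustrative stand-ins for the shared pipeline lock and real pipeline data, not part of this diff:

    import asyncio

    from lightrag.operate import _merge_nodes_then_upsert  # internal helper, shown for illustration

    async def merge_one_entity(entity_name, entities, knowledge_graph_inst,
                               global_config, llm_response_cache):
        # pipeline_status receives "Merging entity: ..." messages while merging runs
        pipeline_status = {"latest_message": "", "history_messages": []}
        pipeline_status_lock = asyncio.Lock()  # stand-in for the shared pipeline status lock
        return await _merge_nodes_then_upsert(
            entity_name,
            entities,
            knowledge_graph_inst,
            global_config,
            pipeline_status,
            pipeline_status_lock,
            llm_response_cache,  # enables the "extract" LLM cache during re-summary
        )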
 
lightrag/utils.py CHANGED
@@ -12,13 +12,17 @@ import re
 from dataclasses import dataclass
 from functools import wraps
 from hashlib import md5
-from typing import Any, Callable
+from typing import Any, Callable, TYPE_CHECKING
 import xml.etree.ElementTree as ET
 import numpy as np
 import tiktoken
 from lightrag.prompt import PROMPTS
 from dotenv import load_dotenv

+# Use TYPE_CHECKING to avoid circular imports
+if TYPE_CHECKING:
+    from lightrag.base import BaseKVStorage
+
 # use the .env that is inside the current folder
 # allows to use different .env file for each lightrag instance
 # the OS environment variables take precedence over the .env file
@@ -908,6 +912,84 @@ def lazy_external_import(module_name: str, class_name: str) -> Callable[..., Any]:
     return import_class


+async def use_llm_func_with_cache(
+    input_text: str,
+    use_llm_func: callable,
+    llm_response_cache: "BaseKVStorage | None" = None,
+    max_tokens: int = None,
+    history_messages: list[dict[str, str]] = None,
+    cache_type: str = "extract",
+) -> str:
+    """Call LLM function with cache support
+
+    If cache is available and enabled (determined by handle_cache based on mode),
+    retrieve result from cache; otherwise call LLM function and save result to cache.
+
+    Args:
+        input_text: Input text to send to LLM
+        use_llm_func: LLM function to call
+        llm_response_cache: Cache storage instance
+        max_tokens: Maximum tokens for generation
+        history_messages: History messages list
+        cache_type: Type of cache
+
+    Returns:
+        LLM response text
+    """
+    if llm_response_cache:
+        if history_messages:
+            history = json.dumps(history_messages, ensure_ascii=False)
+            _prompt = history + "\n" + input_text
+        else:
+            _prompt = input_text
+
+        arg_hash = compute_args_hash(_prompt)
+        cached_return, _1, _2, _3 = await handle_cache(
+            llm_response_cache,
+            arg_hash,
+            _prompt,
+            "default",
+            cache_type=cache_type,
+        )
+        if cached_return:
+            logger.debug(f"Found cache for {arg_hash}")
+            statistic_data["llm_cache"] += 1
+            return cached_return
+        statistic_data["llm_call"] += 1
+
+        # Call LLM
+        kwargs = {}
+        if history_messages:
+            kwargs["history_messages"] = history_messages
+        if max_tokens is not None:
+            kwargs["max_tokens"] = max_tokens
+
+        res: str = await use_llm_func(input_text, **kwargs)
+
+        # Save to cache
+        logger.info(f"Saving LLM cache for {arg_hash}")
+        await save_to_cache(
+            llm_response_cache,
+            CacheData(
+                args_hash=arg_hash,
+                content=res,
+                prompt=_prompt,
+                cache_type=cache_type,
+            ),
+        )
+        return res
+
+    # When cache is disabled, directly call LLM
+    kwargs = {}
+    if history_messages:
+        kwargs["history_messages"] = history_messages
+    if max_tokens is not None:
+        kwargs["max_tokens"] = max_tokens
+
+    logger.info(f"Call LLM function with query text lenght: {len(input_text)}")
+    return await use_llm_func(input_text, **kwargs)
+
+
 def get_content_summary(content: str, max_length: int = 250) -> str:
     """Get summary of document content

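As a quick sanity check of the new helper, a minimal usage sketch; it assumes use_llm_func_with_cache is importable from lightrag.utils as added above, and the dummy LLM function and prompt text are placeholders. With llm_response_cache=None the helper skips handle_cache/save_to_cache and calls the LLM function directly:

    import asyncio

    from lightrag.utils import use_llm_func_with_cache

    async def dummy_llm(prompt: str, **kwargs) -> str:
        # Placeholder for global_config["llm_model_func"]
        return f"summary of: {prompt[:40]}"

    async def main() -> None:
        result = await use_llm_func_with_cache(
            "Combined entity descriptions to re-summarize ...",
            dummy_llm,
            llm_response_cache=None,  # no cache configured, falls through to a direct call
            max_tokens=500,
            cache_type="extract",
        )
        print(result)

    asyncio.run(main())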