zrguo committed · Commit 1fbf326 · unverified · 2 Parent(s): 66250ac 9b78d7d

Merge pull request #1036 from danielaskdd/neo4j-add-min-degree

Refactoring Neo4j implementation and fixing storage init problem for Gunicorn

lightrag/api/lightrag_server.py CHANGED
@@ -50,9 +50,6 @@ from .auth import auth_handler
50
  # This update allows the user to put a different .env file for each lightrag folder
51
  load_dotenv(".env", override=True)
52
 
53
- # Read entity extraction cache config
54
- enable_llm_cache = os.getenv("ENABLE_LLM_CACHE_FOR_EXTRACT", "false").lower() == "true"
55
-
56
  # Initialize config parser
57
  config = configparser.ConfigParser()
58
  config.read("config.ini")
@@ -144,23 +141,25 @@ def create_app(args):
144
  try:
145
  # Initialize database connections
146
  await rag.initialize_storages()
147
- await initialize_pipeline_status()
148
 
149
- # Auto scan documents if enabled
150
- if args.auto_scan_at_startup:
151
- # Check if a task is already running (with lock protection)
152
- pipeline_status = await get_namespace_data("pipeline_status")
153
- should_start_task = False
154
- async with get_pipeline_status_lock():
155
- if not pipeline_status.get("busy", False):
156
- should_start_task = True
157
- # Only start the task if no other task is running
158
- if should_start_task:
159
- # Create background task
160
- task = asyncio.create_task(run_scanning_process(rag, doc_manager))
161
- app.state.background_tasks.add(task)
162
- task.add_done_callback(app.state.background_tasks.discard)
163
- logger.info("Auto scan task started at startup.")
 
 
 
164
 
165
  ASCIIColors.green("\nServer is ready to accept connections! 🚀\n")
166
 
@@ -326,7 +325,7 @@ def create_app(args):
326
  vector_db_storage_cls_kwargs={
327
  "cosine_better_than_threshold": args.cosine_threshold
328
  },
329
- enable_llm_cache_for_entity_extract=enable_llm_cache, # Read from environment variable
330
  embedding_cache_config={
331
  "enabled": True,
332
  "similarity_threshold": 0.95,
@@ -355,7 +354,7 @@ def create_app(args):
355
  vector_db_storage_cls_kwargs={
356
  "cosine_better_than_threshold": args.cosine_threshold
357
  },
358
- enable_llm_cache_for_entity_extract=enable_llm_cache, # Read from environment variable
359
  embedding_cache_config={
360
  "enabled": True,
361
  "similarity_threshold": 0.95,
@@ -419,6 +418,7 @@ def create_app(args):
419
  "doc_status_storage": args.doc_status_storage,
420
  "graph_storage": args.graph_storage,
421
  "vector_storage": args.vector_storage,
 
422
  },
423
  "update_status": update_status,
424
  }
 
50
  # This update allows the user to put a different .env file for each lightrag folder
51
  load_dotenv(".env", override=True)
52
 
 
 
 
53
  # Initialize config parser
54
  config = configparser.ConfigParser()
55
  config.read("config.ini")
 
141
  try:
142
  # Initialize database connections
143
  await rag.initialize_storages()
 
144
 
145
+ await initialize_pipeline_status()
146
+ pipeline_status = await get_namespace_data("pipeline_status")
147
+
148
+ should_start_autoscan = False
149
+ async with get_pipeline_status_lock():
150
+ # Auto scan documents if enabled
151
+ if args.auto_scan_at_startup:
152
+ if not pipeline_status.get("autoscanned", False):
153
+ pipeline_status["autoscanned"] = True
154
+ should_start_autoscan = True
155
+
156
+ # Only run auto scan when no other process started it first
157
+ if should_start_autoscan:
158
+ # Create background task
159
+ task = asyncio.create_task(run_scanning_process(rag, doc_manager))
160
+ app.state.background_tasks.add(task)
161
+ task.add_done_callback(app.state.background_tasks.discard)
162
+ logger.info(f"Process {os.getpid()} auto scan task started at startup.")
163
 
164
  ASCIIColors.green("\nServer is ready to accept connections! 🚀\n")
165
 
 
325
  vector_db_storage_cls_kwargs={
326
  "cosine_better_than_threshold": args.cosine_threshold
327
  },
328
+ enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
329
  embedding_cache_config={
330
  "enabled": True,
331
  "similarity_threshold": 0.95,
 
354
  vector_db_storage_cls_kwargs={
355
  "cosine_better_than_threshold": args.cosine_threshold
356
  },
357
+ enable_llm_cache_for_entity_extract=args.enable_llm_cache_for_extract,
358
  embedding_cache_config={
359
  "enabled": True,
360
  "similarity_threshold": 0.95,
 
418
  "doc_status_storage": args.doc_status_storage,
419
  "graph_storage": args.graph_storage,
420
  "vector_storage": args.vector_storage,
421
+ "enable_llm_cache_for_extract": args.enable_llm_cache_for_extract,
422
  },
423
  "update_status": update_status,
424
  }
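
A note on the startup change above: the old path gated the auto-scan on the pipeline's busy flag, while the new one claims a dedicated "autoscanned" flag under the pipeline status lock, so only the first worker (or first initialization) to reach startup launches run_scanning_process. A minimal, self-contained sketch of that claim-the-flag pattern; the status dict, lock, and scan() coroutine below are illustrative stand-ins, not the LightRAG shared_storage API:

import asyncio

pipeline_status = {"autoscanned": False}   # stand-in for the shared pipeline_status namespace

async def scan(worker_id: int) -> None:
    # stand-in for run_scanning_process(rag, doc_manager)
    print(f"worker {worker_id}: scanning documents")

async def startup(worker_id: int, status_lock: asyncio.Lock) -> None:
    should_start_autoscan = False
    async with status_lock:                        # stand-in for get_pipeline_status_lock()
        if not pipeline_status.get("autoscanned", False):
            pipeline_status["autoscanned"] = True  # claim the scan before releasing the lock
            should_start_autoscan = True
    if should_start_autoscan:
        await scan(worker_id)
    else:
        print(f"worker {worker_id}: auto scan already claimed by another worker")

async def main() -> None:
    status_lock = asyncio.Lock()
    # Two "workers" racing at startup: exactly one of them runs the scan.
    await asyncio.gather(startup(1, status_lock), startup(2, status_lock))

asyncio.run(main())
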
lightrag/api/utils_api.py CHANGED
@@ -362,6 +362,11 @@ def parse_args(is_uvicorn_mode: bool = False) -> argparse.Namespace:
362
  args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
363
  args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
364
 
 
 
 
 
 
365
  # Select Document loading tool (DOCLING, DEFAULT)
366
  args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
367
 
@@ -457,8 +462,10 @@ def display_splash_screen(args: argparse.Namespace) -> None:
457
  ASCIIColors.yellow(f"{args.history_turns}")
458
  ASCIIColors.white(" ├─ Cosine Threshold: ", end="")
459
  ASCIIColors.yellow(f"{args.cosine_threshold}")
460
- ASCIIColors.white(" └─ Top-K: ", end="")
461
  ASCIIColors.yellow(f"{args.top_k}")
 
 
462
 
463
  # System Configuration
464
  ASCIIColors.magenta("\n💾 Storage Configuration:")
 
362
  args.chunk_size = get_env_value("CHUNK_SIZE", 1200, int)
363
  args.chunk_overlap_size = get_env_value("CHUNK_OVERLAP_SIZE", 100, int)
364
 
365
+ # Inject LLM cache configuration
366
+ args.enable_llm_cache_for_extract = get_env_value(
367
+ "ENABLE_LLM_CACHE_FOR_EXTRACT", False, bool
368
+ )
369
+
370
  # Select Document loading tool (DOCLING, DEFAULT)
371
  args.document_loading_engine = get_env_value("DOCUMENT_LOADING_ENGINE", "DEFAULT")
372
 
 
462
  ASCIIColors.yellow(f"{args.history_turns}")
463
  ASCIIColors.white(" ├─ Cosine Threshold: ", end="")
464
  ASCIIColors.yellow(f"{args.cosine_threshold}")
465
+ ASCIIColors.white(" ├─ Top-K: ", end="")
466
  ASCIIColors.yellow(f"{args.top_k}")
467
+ ASCIIColors.white(" └─ LLM Cache for Extraction Enabled: ", end="")
468
+ ASCIIColors.yellow(f"{args.enable_llm_cache_for_extract}")
469
 
470
  # System Configuration
471
  ASCIIColors.magenta("\n💾 Storage Configuration:")
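
The splash screen now also reports the new extraction-cache flag, which parse_args reads through get_env_value("ENABLE_LLM_CACHE_FOR_EXTRACT", False, bool). A rough sketch of how such a boolean env lookup typically behaves; this helper is illustrative only and is not the actual get_env_value implementation:

import os

def read_bool_env(name: str, default: bool) -> bool:
    # Illustrative only: unset -> default, otherwise accept common truthy spellings.
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in ("1", "true", "yes", "on")

# ENABLE_LLM_CACHE_FOR_EXTRACT=true  -> True
# unset                              -> False (the new default)
enable_llm_cache_for_extract = read_bool_env("ENABLE_LLM_CACHE_FOR_EXTRACT", False)
print(enable_llm_cache_for_extract)
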
lightrag/kg/json_doc_status_impl.py CHANGED
@@ -15,6 +15,10 @@ from lightrag.utils import (
15
  from .shared_storage import (
16
  get_namespace_data,
17
  get_storage_lock,
 
 
 
 
18
  try_initialize_namespace,
19
  )
20
 
@@ -27,21 +31,25 @@ class JsonDocStatusStorage(DocStatusStorage):
27
  def __post_init__(self):
28
  working_dir = self.global_config["working_dir"]
29
  self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
30
- self._storage_lock = get_storage_lock()
31
  self._data = None
 
 
32
 
33
  async def initialize(self):
34
  """Initialize storage data"""
35
- # check need_init must before get_namespace_data
36
- need_init = try_initialize_namespace(self.namespace)
37
- self._data = await get_namespace_data(self.namespace)
38
- if need_init:
39
- loaded_data = load_json(self._file_name) or {}
40
- async with self._storage_lock:
41
- self._data.update(loaded_data)
42
- logger.info(
43
- f"Loaded document status storage with {len(loaded_data)} records"
44
- )
 
 
 
45
 
46
  async def filter_keys(self, keys: set[str]) -> set[str]:
47
  """Return keys that should be processed (not in storage or not successfully processed)"""
@@ -87,18 +95,24 @@ class JsonDocStatusStorage(DocStatusStorage):
87
 
88
  async def index_done_callback(self) -> None:
89
  async with self._storage_lock:
90
- data_dict = (
91
- dict(self._data) if hasattr(self._data, "_getvalue") else self._data
92
- )
93
- write_json(data_dict, self._file_name)
 
 
 
 
 
94
 
95
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
96
- logger.info(f"Inserting {len(data)} to {self.namespace}")
97
  if not data:
98
  return
99
-
100
  async with self._storage_lock:
101
  self._data.update(data)
 
 
102
  await self.index_done_callback()
103
 
104
  async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
@@ -109,9 +123,12 @@ class JsonDocStatusStorage(DocStatusStorage):
109
  async with self._storage_lock:
110
  for doc_id in doc_ids:
111
  self._data.pop(doc_id, None)
 
112
  await self.index_done_callback()
113
 
114
  async def drop(self) -> None:
115
  """Drop the storage"""
116
  async with self._storage_lock:
117
  self._data.clear()
 
 
 
15
  from .shared_storage import (
16
  get_namespace_data,
17
  get_storage_lock,
18
+ get_data_init_lock,
19
+ get_update_flag,
20
+ set_all_update_flags,
21
+ clear_all_update_flags,
22
  try_initialize_namespace,
23
  )
24
 
 
31
  def __post_init__(self):
32
  working_dir = self.global_config["working_dir"]
33
  self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
 
34
  self._data = None
35
+ self._storage_lock = None
36
+ self.storage_updated = None
37
 
38
  async def initialize(self):
39
  """Initialize storage data"""
40
+ self._storage_lock = get_storage_lock()
41
+ self.storage_updated = await get_update_flag(self.namespace)
42
+ async with get_data_init_lock():
43
+ # check need_init must before get_namespace_data
44
+ need_init = await try_initialize_namespace(self.namespace)
45
+ self._data = await get_namespace_data(self.namespace)
46
+ if need_init:
47
+ loaded_data = load_json(self._file_name) or {}
48
+ async with self._storage_lock:
49
+ self._data.update(loaded_data)
50
+ logger.info(
51
+ f"Process {os.getpid()} doc status load {self.namespace} with {len(loaded_data)} records"
52
+ )
53
 
54
  async def filter_keys(self, keys: set[str]) -> set[str]:
55
  """Return keys that should be processed (not in storage or not successfully processed)"""
 
95
 
96
  async def index_done_callback(self) -> None:
97
  async with self._storage_lock:
98
+ if self.storage_updated.value:
99
+ data_dict = (
100
+ dict(self._data) if hasattr(self._data, "_getvalue") else self._data
101
+ )
102
+ logger.info(
103
 + f"Process {os.getpid()} doc status writing {len(data_dict)} records to {self.namespace}"
104
+ )
105
+ write_json(data_dict, self._file_name)
106
+ await clear_all_update_flags(self.namespace)
107
 
108
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
 
109
  if not data:
110
  return
111
+ logger.info(f"Inserting {len(data)} records to {self.namespace}")
112
  async with self._storage_lock:
113
  self._data.update(data)
114
+ await set_all_update_flags(self.namespace)
115
+
116
  await self.index_done_callback()
117
 
118
  async def get_by_id(self, id: str) -> Union[dict[str, Any], None]:
 
123
  async with self._storage_lock:
124
  for doc_id in doc_ids:
125
  self._data.pop(doc_id, None)
126
+ await set_all_update_flags(self.namespace)
127
  await self.index_done_callback()
128
 
129
  async def drop(self) -> None:
130
  """Drop the storage"""
131
  async with self._storage_lock:
132
  self._data.clear()
133
+ await set_all_update_flags(self.namespace)
134
+ await self.index_done_callback()
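
The storage now tracks a per-namespace update flag: every mutation (upsert, delete, drop) sets it, and index_done_callback only writes the JSON file, then clears the flags, when its own flag is set. A minimal single-process sketch of that dirty-flag write-back pattern; the flag object below is a simplified stand-in for shared_storage's get_update_flag / set_all_update_flags / clear_all_update_flags:

import asyncio
import json
import os
import tempfile

class UpdateFlag:
    # stand-in for the shared, per-namespace update flag object
    def __init__(self) -> None:
        self.value = False

file_name = os.path.join(tempfile.gettempdir(), "kv_store_sketch.json")

async def main() -> None:
    data: dict = {}
    storage_lock = asyncio.Lock()
    storage_updated = UpdateFlag()

    async def index_done_callback() -> None:
        async with storage_lock:
            if storage_updated.value:           # skip the write when nothing changed
                with open(file_name, "w") as f:
                    json.dump(data, f)
                storage_updated.value = False   # clear_all_update_flags(namespace)

    async def upsert(new_items: dict) -> None:
        async with storage_lock:
            data.update(new_items)
            storage_updated.value = True        # set_all_update_flags(namespace)
        await index_done_callback()

    await upsert({"doc-1": {"status": "processed"}})

asyncio.run(main())
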
lightrag/kg/json_kv_impl.py CHANGED
@@ -13,6 +13,10 @@ from lightrag.utils import (
13
  from .shared_storage import (
14
  get_namespace_data,
15
  get_storage_lock,
 
 
 
 
16
  try_initialize_namespace,
17
  )
18
 
@@ -23,26 +27,63 @@ class JsonKVStorage(BaseKVStorage):
23
  def __post_init__(self):
24
  working_dir = self.global_config["working_dir"]
25
  self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
26
- self._storage_lock = get_storage_lock()
27
  self._data = None
 
 
28
 
29
  async def initialize(self):
30
  """Initialize storage data"""
31
- # check need_init must before get_namespace_data
32
- need_init = try_initialize_namespace(self.namespace)
33
- self._data = await get_namespace_data(self.namespace)
34
- if need_init:
35
- loaded_data = load_json(self._file_name) or {}
36
- async with self._storage_lock:
37
- self._data.update(loaded_data)
38
 - logger.info(f"Load KV {self.namespace} with {len(loaded_data)} data")
39
 
40
  async def index_done_callback(self) -> None:
41
  async with self._storage_lock:
42
- data_dict = (
43
- dict(self._data) if hasattr(self._data, "_getvalue") else self._data
44
- )
45
 - write_json(data_dict, self._file_name)
46
 
47
  async def get_all(self) -> dict[str, Any]:
48
  """Get all data from storage
@@ -73,15 +114,16 @@ class JsonKVStorage(BaseKVStorage):
73
  return set(keys) - set(self._data.keys())
74
 
75
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
76
- logger.info(f"Inserting {len(data)} to {self.namespace}")
77
  if not data:
78
  return
 
79
  async with self._storage_lock:
80
- left_data = {k: v for k, v in data.items() if k not in self._data}
81
- self._data.update(left_data)
82
 
83
  async def delete(self, ids: list[str]) -> None:
84
  async with self._storage_lock:
85
  for doc_id in ids:
86
  self._data.pop(doc_id, None)
 
87
  await self.index_done_callback()
 
13
  from .shared_storage import (
14
  get_namespace_data,
15
  get_storage_lock,
16
+ get_data_init_lock,
17
+ get_update_flag,
18
+ set_all_update_flags,
19
+ clear_all_update_flags,
20
  try_initialize_namespace,
21
  )
22
 
 
27
  def __post_init__(self):
28
  working_dir = self.global_config["working_dir"]
29
  self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
 
30
  self._data = None
31
+ self._storage_lock = None
32
+ self.storage_updated = None
33
 
34
  async def initialize(self):
35
  """Initialize storage data"""
36
+ self._storage_lock = get_storage_lock()
37
+ self.storage_updated = await get_update_flag(self.namespace)
38
+ async with get_data_init_lock():
39
+ # check need_init must before get_namespace_data
40
+ need_init = await try_initialize_namespace(self.namespace)
41
+ self._data = await get_namespace_data(self.namespace)
42
+ if need_init:
43
+ loaded_data = load_json(self._file_name) or {}
44
+ async with self._storage_lock:
45
+ self._data.update(loaded_data)
46
+
47
+ # Calculate data count based on namespace
48
+ if self.namespace.endswith("cache"):
49
+ # For cache namespaces, sum the cache entries across all cache types
50
+ data_count = sum(
51
+ len(first_level_dict)
52
+ for first_level_dict in loaded_data.values()
53
+ if isinstance(first_level_dict, dict)
54
+ )
55
+ else:
56
+ # For non-cache namespaces, use the original count method
57
+ data_count = len(loaded_data)
58
+
59
+ logger.info(
60
+ f"Process {os.getpid()} KV load {self.namespace} with {data_count} records"
61
+ )
62
 
63
  async def index_done_callback(self) -> None:
64
  async with self._storage_lock:
65
+ if self.storage_updated.value:
66
+ data_dict = (
67
+ dict(self._data) if hasattr(self._data, "_getvalue") else self._data
68
+ )
69
+
70
+ # Calculate data count based on namespace
71
+ if self.namespace.endswith("cache"):
72
 + # For cache namespaces, sum the cache entries across all cache types
73
+ data_count = sum(
74
+ len(first_level_dict)
75
+ for first_level_dict in data_dict.values()
76
+ if isinstance(first_level_dict, dict)
77
+ )
78
+ else:
79
+ # For non-cache namespaces, use the original count method
80
+ data_count = len(data_dict)
81
+
82
+ logger.info(
83
 + f"Process {os.getpid()} KV writing {data_count} records to {self.namespace}"
84
+ )
85
+ write_json(data_dict, self._file_name)
86
+ await clear_all_update_flags(self.namespace)
87
 
88
  async def get_all(self) -> dict[str, Any]:
89
  """Get all data from storage
 
114
  return set(keys) - set(self._data.keys())
115
 
116
  async def upsert(self, data: dict[str, dict[str, Any]]) -> None:
 
117
  if not data:
118
  return
119
+ logger.info(f"Inserting {len(data)} records to {self.namespace}")
120
  async with self._storage_lock:
121
+ self._data.update(data)
122
+ await set_all_update_flags(self.namespace)
123
 
124
  async def delete(self, ids: list[str]) -> None:
125
  async with self._storage_lock:
126
  for doc_id in ids:
127
  self._data.pop(doc_id, None)
128
+ await set_all_update_flags(self.namespace)
129
  await self.index_done_callback()
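
For cache namespaces the record count is computed differently: each top-level key holds a dict of cached entries, so the log counts the inner entries rather than the top-level keys. A small illustrative example of that counting logic (the keys and shapes here are made up):

loaded_data = {
    "default": {"hash-a": {"return": "..."}, "hash-b": {"return": "..."}},
    "global":  {"hash-c": {"return": "..."}, "hash-d": {"return": "..."}},
    "flag": "not-a-dict",   # non-dict values are skipped by the sum
}

data_count = sum(
    len(first_level_dict)
    for first_level_dict in loaded_data.values()
    if isinstance(first_level_dict, dict)
)
print(data_count)  # 4 cache entries, although there are only 3 top-level keys
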
lightrag/kg/neo4j_impl.py CHANGED
@@ -3,7 +3,7 @@ import inspect
3
  import os
4
  import re
5
  from dataclasses import dataclass
6
- from typing import Any, List, Dict, final
7
  import numpy as np
8
  import configparser
9
 
@@ -15,6 +15,7 @@ from tenacity import (
15
  retry_if_exception_type,
16
  )
17
 
 
18
  from ..utils import logger
19
  from ..base import BaseGraphStorage
20
  from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
@@ -37,6 +38,9 @@ config.read("config.ini", "utf-8")
37
  # Get maximum number of graph nodes from environment variable, default is 1000
38
  MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
39
 
 
 
 
40
 
41
  @final
42
  @dataclass
@@ -60,19 +64,25 @@ class Neo4JStorage(BaseGraphStorage):
60
  MAX_CONNECTION_POOL_SIZE = int(
61
  os.environ.get(
62
  "NEO4J_MAX_CONNECTION_POOL_SIZE",
63
- config.get("neo4j", "connection_pool_size", fallback=800),
64
  )
65
  )
66
  CONNECTION_TIMEOUT = float(
67
  os.environ.get(
68
  "NEO4J_CONNECTION_TIMEOUT",
69
- config.get("neo4j", "connection_timeout", fallback=60.0),
70
  ),
71
  )
72
  CONNECTION_ACQUISITION_TIMEOUT = float(
73
  os.environ.get(
74
  "NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
75
- config.get("neo4j", "connection_acquisition_timeout", fallback=60.0),
 
 
 
 
 
 
76
  ),
77
  )
78
  DATABASE = os.environ.get(
@@ -85,6 +95,7 @@ class Neo4JStorage(BaseGraphStorage):
85
  max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
86
  connection_timeout=CONNECTION_TIMEOUT,
87
  connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
 
88
  )
89
 
90
  # Try to connect to the database
@@ -152,65 +163,103 @@ class Neo4JStorage(BaseGraphStorage):
152
  }
153
 
154
  async def close(self):
 
155
  if self._driver:
156
  await self._driver.close()
157
  self._driver = None
158
 
159
  async def __aexit__(self, exc_type, exc, tb):
160
- if self._driver:
161
- await self._driver.close()
162
 
163
  async def index_done_callback(self) -> None:
164
   # Neo4j handles persistence automatically
165
  pass
166
 
167
- async def _label_exists(self, label: str) -> bool:
168
- """Check if a label exists in the Neo4j database."""
169
- query = "CALL db.labels() YIELD label RETURN label"
170
- try:
171
- async with self._driver.session(database=self._DATABASE) as session:
172
- result = await session.run(query)
173
- labels = [record["label"] for record in await result.data()]
174
- return label in labels
175
- except Exception as e:
176
- logger.error(f"Error checking label existence: {e}")
177
- return False
178
 
179
- async def _ensure_label(self, label: str) -> str:
180
- """Ensure a label exists by validating it."""
 
 
 
 
181
  clean_label = label.strip('"')
182
- if not await self._label_exists(clean_label):
183
- logger.warning(f"Label '{clean_label}' does not exist in Neo4j")
184
  return clean_label
185
 
186
  async def has_node(self, node_id: str) -> bool:
187
- entity_name_label = await self._ensure_label(node_id)
188
- async with self._driver.session(database=self._DATABASE) as session:
189
- query = (
190
- f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
191
- )
192
- result = await session.run(query)
193
- single_result = await result.single()
194
- logger.debug(
195
- f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['node_exists']}"
196
- )
197
 - return single_result["node_exists"]
198
 
199
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
200
- entity_name_label_source = source_node_id.strip('"')
201
- entity_name_label_target = target_node_id.strip('"')
202
 
203
- async with self._driver.session(database=self._DATABASE) as session:
204
- query = (
205
- f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
206
- "RETURN COUNT(r) > 0 AS edgeExists"
207
- )
208
- result = await session.run(query)
209
- single_result = await result.single()
210
- logger.debug(
211
- f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{single_result['edgeExists']}"
212
- )
213
 - return single_result["edgeExists"]
214
 
215
  async def get_node(self, node_id: str) -> dict[str, str] | None:
216
  """Get node by its label identifier.
@@ -221,43 +270,108 @@ class Neo4JStorage(BaseGraphStorage):
221
  Returns:
222
  dict: Node properties if found
223
  None: If node not found
 
 
 
 
224
  """
225
- async with self._driver.session(database=self._DATABASE) as session:
226
- entity_name_label = await self._ensure_label(node_id)
227
- query = f"MATCH (n:`{entity_name_label}`) RETURN n"
228
- result = await session.run(query)
229
- record = await result.single()
230
- if record:
231
- node = record["n"]
232
- node_dict = dict(node)
233
- logger.debug(
234
- f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}"
235
- )
236
- return node_dict
237
 - return None
 
238
 
239
  async def node_degree(self, node_id: str) -> int:
240
- entity_name_label = node_id.strip('"')
 
 
241
 
242
- async with self._driver.session(database=self._DATABASE) as session:
243
- query = f"""
244
- MATCH (n:`{entity_name_label}`)
245
- RETURN COUNT{{ (n)--() }} AS totalEdgeCount
246
- """
247
- result = await session.run(query)
248
- record = await result.single()
249
- if record:
250
- edge_count = record["totalEdgeCount"]
251
- logger.debug(
252
 - f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_count}"
 
253
  )
254
- return edge_count
255
- else:
256
- return None
257
 
258
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
259
- entity_name_label_source = src_id.strip('"')
260
- entity_name_label_target = tgt_id.strip('"')
 
 
 
 
 
 
 
 
 
 
261
  src_degree = await self.node_degree(entity_name_label_source)
262
  trg_degree = await self.node_degree(entity_name_label_target)
263
 
@@ -266,116 +380,152 @@ class Neo4JStorage(BaseGraphStorage):
266
  trg_degree = 0 if trg_degree is None else trg_degree
267
 
268
  degrees = int(src_degree) + int(trg_degree)
269
- logger.debug(
270
- f"{inspect.currentframe().f_code.co_name}:query:src_Degree+trg_degree:result:{degrees}"
271
- )
272
  return degrees
273
 
274
  async def get_edge(
275
  self, source_node_id: str, target_node_id: str
276
  ) -> dict[str, str] | None:
 
277
  try:
278
- entity_name_label_source = source_node_id.strip('"')
279
- entity_name_label_target = target_node_id.strip('"')
280
 
281
- async with self._driver.session(database=self._DATABASE) as session:
 
 
282
  query = f"""
283
- MATCH (start:`{entity_name_label_source}`)-[r]->(end:`{entity_name_label_target}`)
284
  RETURN properties(r) as edge_properties
285
- LIMIT 1
286
  """
287
 
288
  result = await session.run(query)
289
- record = await result.single()
290
- if record:
291
- try:
292
- result = dict(record["edge_properties"])
293
- logger.info(f"Result: {result}")
294
- # Ensure required keys exist with defaults
295
- required_keys = {
296
- "weight": 0.0,
297
- "source_id": None,
298
- "description": None,
299
- "keywords": None,
300
- }
301
- for key, default_value in required_keys.items():
302
- if key not in result:
303
- result[key] = default_value
304
- logger.warning(
305
- f"Edge between {entity_name_label_source} and {entity_name_label_target} "
306
- f"missing {key}, using default: {default_value}"
307
- )
308
 
309
- logger.debug(
310
- f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{result}"
311
- )
312
- return result
313
- except (KeyError, TypeError, ValueError) as e:
314
- logger.error(
315
- f"Error processing edge properties between {entity_name_label_source} "
316
- f"and {entity_name_label_target}: {str(e)}"
317
  )
318
- # Return default edge properties on error
319
- return {
320
- "weight": 0.0,
321
- "description": None,
322
- "keywords": None,
323
- "source_id": None,
324
- }
325
-
326
- logger.debug(
327
- f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
328
- )
329
- # Return default edge properties when no edge found
330
- return {
331
- "weight": 0.0,
332
- "description": None,
333
- "keywords": None,
334
- "source_id": None,
335
 - }
 
336
 
337
  except Exception as e:
338
  logger.error(
339
  f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
340
  )
341
- # Return default edge properties on error
342
- return {
343
- "weight": 0.0,
344
- "description": None,
345
- "keywords": None,
346
- "source_id": None,
347
- }
348
 
349
  async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
350
- node_label = source_node_id.strip('"')
351
 
 
 
 
 
 
 
 
 
 
 
352
  """
353
- Retrieves all edges (relationships) for a particular node identified by its label.
354
- :return: List of dictionaries containing edge information
355
- """
356
- query = f"""MATCH (n:`{node_label}`)
357
- OPTIONAL MATCH (n)-[r]-(connected)
358
- RETURN n, r, connected"""
359
- async with self._driver.session(database=self._DATABASE) as session:
360
- results = await session.run(query)
361
- edges = []
362
- async for record in results:
363
- source_node = record["n"]
364
- connected_node = record["connected"]
365
-
366
- source_label = (
367
- list(source_node.labels)[0] if source_node.labels else None
368
- )
369
- target_label = (
370
- list(connected_node.labels)[0]
371
- if connected_node and connected_node.labels
372
- else None
373
- )
374
 
375
- if source_label and target_label:
376
- edges.append((source_label, target_label))
 
377
 
378
 - return edges
 
379
 
380
  @retry(
381
  stop=stop_after_attempt(3),
@@ -397,26 +547,88 @@ class Neo4JStorage(BaseGraphStorage):
397
  node_id: The unique identifier for the node (used as label)
398
  node_data: Dictionary of node properties
399
  """
400
- label = await self._ensure_label(node_id)
401
  properties = node_data
402
-
403
- async def _do_upsert(tx: AsyncManagedTransaction):
404
- query = f"""
405
- MERGE (n:`{label}`)
406
- SET n += $properties
407
- """
408
- await tx.run(query, properties=properties)
409
- logger.debug(
410
- f"Upserted node with label '{label}' and properties: {properties}"
411
- )
412
 
413
  try:
414
  async with self._driver.session(database=self._DATABASE) as session:
415
 - await session.execute_write(_do_upsert)
 
416
  except Exception as e:
417
  logger.error(f"Error during upsert: {str(e)}")
418
  raise
419
 
420
  @retry(
421
  stop=stop_after_attempt(3),
422
  wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -434,34 +646,55 @@ class Neo4JStorage(BaseGraphStorage):
434
  ) -> None:
435
  """
436
  Upsert an edge and its properties between two nodes identified by their labels.
 
 
437
 
438
  Args:
439
  source_node_id (str): Label of the source node (used as identifier)
440
  target_node_id (str): Label of the target node (used as identifier)
441
  edge_data (dict): Dictionary of properties to set on the edge
 
 
 
442
  """
443
- source_label = await self._ensure_label(source_node_id)
444
- target_label = await self._ensure_label(target_node_id)
445
  edge_properties = edge_data
446
 
447
- async def _do_upsert_edge(tx: AsyncManagedTransaction):
448
- query = f"""
449
- MATCH (source:`{source_label}`)
450
- WITH source
451
- MATCH (target:`{target_label}`)
452
- MERGE (source)-[r:DIRECTED]->(target)
453
- SET r += $properties
454
- RETURN r
455
- """
456
- result = await tx.run(query, properties=edge_properties)
457
- record = await result.single()
458
- logger.debug(
459
- f"Upserted edge from '{source_label}' to '{target_label}' with properties: {edge_properties}, result: {record['r'] if record else None}"
460
- )
461
 
462
  try:
463
  async with self._driver.session(database=self._DATABASE) as session:
464
- await session.execute_write(_do_upsert_edge)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
465
  except Exception as e:
466
  logger.error(f"Error during edge upsert: {str(e)}")
467
  raise
@@ -470,199 +703,286 @@ class Neo4JStorage(BaseGraphStorage):
470
  print("Implemented but never called.")
471
 
472
  async def get_knowledge_graph(
473
- self, node_label: str, max_depth: int = 5
 
 
 
 
474
  ) -> KnowledgeGraph:
475
  """
476
  Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
477
  Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
478
  When reducing the number of nodes, the prioritization criteria are as follows:
479
- 1. Label matching nodes take precedence (nodes containing the specified label string)
480
- 2. Followed by nodes directly connected to the matching nodes
481
- 3. Finally, the degree of the nodes
 
482
 
483
  Args:
484
- node_label (str): String to match in node labels (will match any node containing this string in its label)
485
- max_depth (int, optional): Maximum depth of the graph. Defaults to 5.
 
 
486
  Returns:
487
  KnowledgeGraph: Complete connected subgraph for specified node
488
  """
489
  label = node_label.strip('"')
490
- # Escape single quotes to prevent injection attacks
491
- escaped_label = label.replace("'", "\\'")
492
  result = KnowledgeGraph()
493
  seen_nodes = set()
494
  seen_edges = set()
495
 
496
- async with self._driver.session(database=self._DATABASE) as session:
 
 
497
  try:
498
  if label == "*":
499
  main_query = """
500
  MATCH (n)
501
  OPTIONAL MATCH (n)-[r]-()
502
  WITH n, count(r) AS degree
 
503
  ORDER BY degree DESC
504
  LIMIT $max_nodes
505
- WITH collect(n) AS nodes
506
- MATCH (a)-[r]->(b)
507
- WHERE a IN nodes AND b IN nodes
508
- RETURN nodes, collect(DISTINCT r) AS relationships
 
 
 
509
  """
510
  result_set = await session.run(
511
- main_query, {"max_nodes": MAX_GRAPH_NODES}
 
512
  )
513
 
514
  else:
515
- validate_query = f"""
516
- MATCH (n)
517
- WHERE any(label IN labels(n) WHERE label CONTAINS '{escaped_label}')
518
- RETURN n LIMIT 1
519
- """
520
- validate_result = await session.run(validate_query)
521
- if not await validate_result.single():
522
- logger.warning(
523
- f"No nodes containing '{label}' in their labels found!"
524
- )
525
- return result
526
-
527
  # Main query uses partial matching
528
- main_query = f"""
529
  MATCH (start)
530
- WHERE any(label IN labels(start) WHERE label CONTAINS '{escaped_label}')
 
 
 
 
 
531
  WITH start
532
- CALL apoc.path.subgraphAll(start, {{
533
- relationshipFilter: '>',
534
  minLevel: 0,
535
- maxLevel: {max_depth},
536
  bfs: true
537
- }})
538
  YIELD nodes, relationships
539
  WITH start, nodes, relationships
540
  UNWIND nodes AS node
541
  OPTIONAL MATCH (node)-[r]-()
542
- WITH node, count(r) AS degree, start, nodes, relationships,
543
- CASE
544
- WHEN id(node) = id(start) THEN 2
545
- WHEN EXISTS((start)-->(node)) OR EXISTS((node)-->(start)) THEN 1
546
- ELSE 0
547
- END AS priority
548
- ORDER BY priority DESC, degree DESC
 
 
549
  LIMIT $max_nodes
550
- WITH collect(node) AS filtered_nodes, nodes, relationships
551
- RETURN filtered_nodes AS nodes,
552
- [rel IN relationships WHERE startNode(rel) IN filtered_nodes AND endNode(rel) IN filtered_nodes] AS relationships
 
 
 
 
553
  """
554
  result_set = await session.run(
555
- main_query, {"max_nodes": MAX_GRAPH_NODES}
 
 
 
 
 
 
 
556
  )
557
 
558
- record = await result_set.single()
559
-
560
- if record:
561
- # Handle nodes (compatible with multi-label cases)
562
- for node in record["nodes"]:
563
- # Use node ID + label combination as unique identifier
564
- node_id = node.id
565
- if node_id not in seen_nodes:
566
- result.nodes.append(
567
- KnowledgeGraphNode(
568
- id=f"{node_id}",
569
- labels=list(node.labels),
570
- properties=dict(node),
 
 
571
  )
572
- )
573
- seen_nodes.add(node_id)
574
-
575
- # Handle relationships (including direction information)
576
- for rel in record["relationships"]:
577
- edge_id = rel.id
578
- if edge_id not in seen_edges:
579
- start = rel.start_node
580
- end = rel.end_node
581
- result.edges.append(
582
- KnowledgeGraphEdge(
583
- id=f"{edge_id}",
584
- type=rel.type,
585
- source=f"{start.id}",
586
- target=f"{end.id}",
587
- properties=dict(rel),
588
  )
589
- )
590
- seen_edges.add(edge_id)
591
 
592
- logger.info(
593
- f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
594
- )
 
 
595
 
596
  except neo4jExceptions.ClientError as e:
597
- logger.error(f"APOC query failed: {str(e)}")
598
- return await self._robust_fallback(label, max_depth)
 
 
 
 
 
 
 
 
599
 
600
  return result
601
 
602
  async def _robust_fallback(
603
- self, label: str, max_depth: int
604
- ) -> Dict[str, List[Dict]]:
605
- """Enhanced fallback query solution"""
606
- result = {"nodes": [], "edges": []}
 
 
 
 
607
  visited_nodes = set()
608
  visited_edges = set()
609
 
610
- async def traverse(current_label: str, current_depth: int):
 
 
 
 
 
611
  if current_depth > max_depth:
 
612
  return
613
-
614
- # Get current node details
615
- node = await self.get_node(current_label)
616
- if not node:
617
  return
618
 
619
- node_id = f"{current_label}"
620
- if node_id in visited_nodes:
621
  return
622
- visited_nodes.add(node_id)
623
 
624
- # Add node data (with complete labels)
625
- node_data = {k: v for k, v in node.items()}
626
- node_data["labels"] = [
627
- current_label
628
- ] # Assume get_node method returns label information
629
 - result["nodes"].append(node_data)
 
630
 
631
- # Get all outgoing and incoming edges
 
 
 
632
  query = f"""
633
- MATCH (a)-[r]-(b)
634
- WHERE a:`{current_label}` OR b:`{current_label}`
635
- RETURN a, r, b,
636
- CASE WHEN startNode(r) = a THEN 'OUTGOING' ELSE 'INCOMING' END AS direction
637
  """
638
- async with self._driver.session(database=self._DATABASE) as session:
639
- results = await session.run(query)
640
- async for record in results:
641
- # Handle edges
642
- rel = record["r"]
643
- edge_id = f"{rel.id}_{rel.type}"
644
- if edge_id not in visited_edges:
645
- edge_data = dict(rel)
646
- edge_data.update(
647
- {
648
- "source": list(record["a"].labels)[0],
649
- "target": list(record["b"].labels)[0],
650
- "type": rel.type,
651
- "direction": record["direction"],
652
- }
653
- )
654
- result["edges"].append(edge_data)
655
- visited_edges.add(edge_id)
656
-
657
- # Recursively traverse adjacent nodes
658
- next_label = (
659
- list(record["b"].labels)[0]
660
- if record["direction"] == "OUTGOING"
661
- else list(record["a"].labels)[0]
662
- )
663
- await traverse(next_label, current_depth + 1)
664
 
665
- await traverse(label, 0)
666
  return result
667
 
668
  async def get_all_labels(self) -> list[str]:
@@ -671,7 +991,9 @@ class Neo4JStorage(BaseGraphStorage):
671
  Returns:
672
  ["Person", "Company", ...] # Alphabetically sorted label list
673
  """
674
- async with self._driver.session(database=self._DATABASE) as session:
 
 
675
  # Method 1: Direct metadata query (Available for Neo4j 4.3+)
676
  # query = "CALL db.labels() YIELD label RETURN label"
677
 
@@ -683,11 +1005,15 @@ class Neo4JStorage(BaseGraphStorage):
683
  RETURN DISTINCT label
684
  ORDER BY label
685
  """
686
-
687
  result = await session.run(query)
688
  labels = []
689
- async for record in result:
690
- labels.append(record["label"])
 
 
 
 
 
691
  return labels
692
 
693
  @retry(
@@ -708,15 +1034,16 @@ class Neo4JStorage(BaseGraphStorage):
708
  Args:
709
  node_id: The label of the node to delete
710
  """
711
- label = await self._ensure_label(node_id)
712
 
713
  async def _do_delete(tx: AsyncManagedTransaction):
714
  query = f"""
715
  MATCH (n:`{label}`)
716
  DETACH DELETE n
717
  """
718
- await tx.run(query)
719
  logger.debug(f"Deleted node with label '{label}'")
 
720
 
721
  try:
722
  async with self._driver.session(database=self._DATABASE) as session:
@@ -765,16 +1092,17 @@ class Neo4JStorage(BaseGraphStorage):
765
  edges: List of edges to be deleted, each edge is a (source, target) tuple
766
  """
767
  for source, target in edges:
768
- source_label = await self._ensure_label(source)
769
- target_label = await self._ensure_label(target)
770
 
771
  async def _do_delete_edge(tx: AsyncManagedTransaction):
772
  query = f"""
773
- MATCH (source:`{source_label}`)-[r]->(target:`{target_label}`)
774
  DELETE r
775
  """
776
- await tx.run(query)
777
  logger.debug(f"Deleted edge from '{source_label}' to '{target_label}'")
 
778
 
779
  try:
780
  async with self._driver.session(database=self._DATABASE) as session:
 
3
  import os
4
  import re
5
  from dataclasses import dataclass
6
+ from typing import Any, final, Optional
7
  import numpy as np
8
  import configparser
9
 
 
15
  retry_if_exception_type,
16
  )
17
 
18
+ import logging
19
  from ..utils import logger
20
  from ..base import BaseGraphStorage
21
  from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
 
38
  # Get maximum number of graph nodes from environment variable, default is 1000
39
  MAX_GRAPH_NODES = int(os.getenv("MAX_GRAPH_NODES", 1000))
40
 
41
+ # Set neo4j logger level to ERROR to suppress warning logs
42
+ logging.getLogger("neo4j").setLevel(logging.ERROR)
43
+
44
 
45
  @final
46
  @dataclass
 
64
  MAX_CONNECTION_POOL_SIZE = int(
65
  os.environ.get(
66
  "NEO4J_MAX_CONNECTION_POOL_SIZE",
67
+ config.get("neo4j", "connection_pool_size", fallback=50),
68
  )
69
  )
70
  CONNECTION_TIMEOUT = float(
71
  os.environ.get(
72
  "NEO4J_CONNECTION_TIMEOUT",
73
+ config.get("neo4j", "connection_timeout", fallback=30.0),
74
  ),
75
  )
76
  CONNECTION_ACQUISITION_TIMEOUT = float(
77
  os.environ.get(
78
  "NEO4J_CONNECTION_ACQUISITION_TIMEOUT",
79
+ config.get("neo4j", "connection_acquisition_timeout", fallback=30.0),
80
+ ),
81
+ )
82
+ MAX_TRANSACTION_RETRY_TIME = float(
83
+ os.environ.get(
84
+ "NEO4J_MAX_TRANSACTION_RETRY_TIME",
85
+ config.get("neo4j", "max_transaction_retry_time", fallback=30.0),
86
  ),
87
  )
88
  DATABASE = os.environ.get(
 
95
  max_connection_pool_size=MAX_CONNECTION_POOL_SIZE,
96
  connection_timeout=CONNECTION_TIMEOUT,
97
  connection_acquisition_timeout=CONNECTION_ACQUISITION_TIMEOUT,
98
+ max_transaction_retry_time=MAX_TRANSACTION_RETRY_TIME,
99
  )
100
 
101
  # Try to connect to the database
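
Each of the driver settings above resolves in the same order: environment variable first, then the [neo4j] section of config.ini, then a hard-coded fallback. A reduced sketch of that lookup chain; the helper function is illustrative, while the env/ini key names and the 30.0 fallback are taken from the diff:

import configparser
import os

config = configparser.ConfigParser()
config.read("config.ini")  # the file may be absent; the fallback below still applies

def neo4j_float_setting(env_name: str, ini_key: str, fallback: float) -> float:
    # env var wins, then config.ini [neo4j], then the hard-coded fallback
    return float(
        os.environ.get(env_name, config.get("neo4j", ini_key, fallback=fallback))
    )

max_retry = neo4j_float_setting(
    "NEO4J_MAX_TRANSACTION_RETRY_TIME", "max_transaction_retry_time", 30.0
)
print(max_retry)
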
 
163
  }
164
 
165
  async def close(self):
166
+ """Close the Neo4j driver and release all resources"""
167
  if self._driver:
168
  await self._driver.close()
169
  self._driver = None
170
 
171
  async def __aexit__(self, exc_type, exc, tb):
172
+ """Ensure driver is closed when context manager exits"""
173
+ await self.close()
174
 
175
  async def index_done_callback(self) -> None:
176
   # Neo4j handles persistence automatically
177
  pass
178
 
179
+ def _ensure_label(self, label: str) -> str:
180
+ """Ensure a label is valid
181
+
182
+ Args:
183
+ label: The label to validate
 
 
 
 
 
 
184
 
185
+ Returns:
186
+ str: The cleaned label
187
+
188
+ Raises:
189
+ ValueError: If label is empty after cleaning
190
+ """
191
  clean_label = label.strip('"')
192
+ if not clean_label:
193
+ raise ValueError("Neo4j: Label cannot be empty")
194
  return clean_label
195
 
196
  async def has_node(self, node_id: str) -> bool:
197
+ """
198
+ Check if a node with the given label exists in the database
199
+
200
+ Args:
201
+ node_id: Label of the node to check
202
+
203
+ Returns:
204
+ bool: True if node exists, False otherwise
205
+
206
+ Raises:
207
+ ValueError: If node_id is invalid
208
+ Exception: If there is an error executing the query
209
+ """
210
+ entity_name_label = self._ensure_label(node_id)
211
+ async with self._driver.session(
212
+ database=self._DATABASE, default_access_mode="READ"
213
+ ) as session:
214
+ try:
215
+ query = f"MATCH (n:`{entity_name_label}`) RETURN count(n) > 0 AS node_exists"
216
+ result = await session.run(query)
217
+ single_result = await result.single()
218
+ await result.consume() # Ensure result is fully consumed
219
+ return single_result["node_exists"]
220
+ except Exception as e:
221
+ logger.error(
222
+ f"Error checking node existence for {entity_name_label}: {str(e)}"
223
+ )
224
+ await result.consume() # Ensure results are consumed even on error
225
+ raise
226
 
227
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
228
+ """
229
+ Check if an edge exists between two nodes
230
 
231
+ Args:
232
+ source_node_id: Label of the source node
233
+ target_node_id: Label of the target node
234
+
235
+ Returns:
236
+ bool: True if edge exists, False otherwise
237
+
238
+ Raises:
239
+ ValueError: If either node_id is invalid
240
+ Exception: If there is an error executing the query
241
+ """
242
+ entity_name_label_source = self._ensure_label(source_node_id)
243
+ entity_name_label_target = self._ensure_label(target_node_id)
244
+
245
+ async with self._driver.session(
246
+ database=self._DATABASE, default_access_mode="READ"
247
+ ) as session:
248
+ try:
249
+ query = (
250
+ f"MATCH (a:`{entity_name_label_source}`)-[r]-(b:`{entity_name_label_target}`) "
251
+ "RETURN COUNT(r) > 0 AS edgeExists"
252
+ )
253
+ result = await session.run(query)
254
+ single_result = await result.single()
255
+ await result.consume() # Ensure result is fully consumed
256
+ return single_result["edgeExists"]
257
+ except Exception as e:
258
+ logger.error(
259
+ f"Error checking edge existence between {entity_name_label_source} and {entity_name_label_target}: {str(e)}"
260
+ )
261
+ await result.consume() # Ensure results are consumed even on error
262
+ raise
263
 
264
  async def get_node(self, node_id: str) -> dict[str, str] | None:
265
  """Get node by its label identifier.
 
270
  Returns:
271
  dict: Node properties if found
272
  None: If node not found
273
+
274
+ Raises:
275
+ ValueError: If node_id is invalid
276
+ Exception: If there is an error executing the query
277
  """
278
+ entity_name_label = self._ensure_label(node_id)
279
+ async with self._driver.session(
280
+ database=self._DATABASE, default_access_mode="READ"
281
+ ) as session:
282
+ try:
283
+ query = f"MATCH (n:`{entity_name_label}` {{entity_id: $entity_id}}) RETURN n"
284
+ result = await session.run(query, entity_id=entity_name_label)
285
+ try:
286
+ records = await result.fetch(
287
+ 2
288
+ ) # Get 2 records for duplication check
289
+
290
+ if len(records) > 1:
291
+ logger.warning(
292
+ f"Multiple nodes found with label '{entity_name_label}'. Using first node."
293
+ )
294
+ if records:
295
+ node = records[0]["n"]
296
+ node_dict = dict(node)
297
+ logger.debug(
298
+ f"{inspect.currentframe().f_code.co_name}: query: {query}, result: {node_dict}"
299
+ )
300
+ return node_dict
301
+ return None
302
+ finally:
303
+ await result.consume() # Ensure result is fully consumed
304
+ except Exception as e:
305
+ logger.error(f"Error getting node for {entity_name_label}: {str(e)}")
306
+ raise
307
 
308
  async def node_degree(self, node_id: str) -> int:
309
+ """Get the degree (number of relationships) of a node with the given label.
310
+ If multiple nodes have the same label, returns the degree of the first node.
311
+ If no node is found, returns 0.
312
 
313
+ Args:
314
+ node_id: The label of the node
315
+
316
+ Returns:
317
+ int: The number of relationships the node has, or 0 if no node found
318
+
319
+ Raises:
320
+ ValueError: If node_id is invalid
321
+ Exception: If there is an error executing the query
322
+ """
323
+ entity_name_label = self._ensure_label(node_id)
324
+
325
+ async with self._driver.session(
326
+ database=self._DATABASE, default_access_mode="READ"
327
+ ) as session:
328
+ try:
329
+ query = f"""
330
+ MATCH (n:`{entity_name_label}`)
331
+ OPTIONAL MATCH (n)-[r]-()
332
+ RETURN n, COUNT(r) AS degree
333
+ """
334
+ result = await session.run(query)
335
+ try:
336
+ records = await result.fetch(100)
337
+
338
+ if not records:
339
+ logger.warning(
340
+ f"No node found with label '{entity_name_label}'"
341
+ )
342
+ return 0
343
+
344
+ if len(records) > 1:
345
+ logger.warning(
346
+ f"Multiple nodes ({len(records)}) found with label '{entity_name_label}', using first node's degree"
347
+ )
348
+
349
+ degree = records[0]["degree"]
350
+ logger.debug(
351
+ f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{degree}"
352
+ )
353
+ return degree
354
+ finally:
355
+ await result.consume() # Ensure result is fully consumed
356
+ except Exception as e:
357
+ logger.error(
358
+ f"Error getting node degree for {entity_name_label}: {str(e)}"
359
  )
360
+ raise
 
 
361
 
362
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
363
+ """Get the total degree (sum of relationships) of two nodes.
364
+
365
+ Args:
366
+ src_id: Label of the source node
367
+ tgt_id: Label of the target node
368
+
369
+ Returns:
370
+ int: Sum of the degrees of both nodes
371
+ """
372
+ entity_name_label_source = self._ensure_label(src_id)
373
+ entity_name_label_target = self._ensure_label(tgt_id)
374
+
375
  src_degree = await self.node_degree(entity_name_label_source)
376
  trg_degree = await self.node_degree(entity_name_label_target)
377
 
 
380
  trg_degree = 0 if trg_degree is None else trg_degree
381
 
382
  degrees = int(src_degree) + int(trg_degree)
 
 
 
383
  return degrees
384
 
385
  async def get_edge(
386
  self, source_node_id: str, target_node_id: str
387
  ) -> dict[str, str] | None:
388
+ """Get edge properties between two nodes.
389
+
390
+ Args:
391
+ source_node_id: Label of the source node
392
+ target_node_id: Label of the target node
393
+
394
+ Returns:
395
+ dict: Edge properties if found, default properties if not found or on error
396
+
397
+ Raises:
398
+ ValueError: If either node_id is invalid
399
+ Exception: If there is an error executing the query
400
+ """
401
  try:
402
+ entity_name_label_source = self._ensure_label(source_node_id)
403
+ entity_name_label_target = self._ensure_label(target_node_id)
404
 
405
+ async with self._driver.session(
406
+ database=self._DATABASE, default_access_mode="READ"
407
+ ) as session:
408
  query = f"""
409
+ MATCH (start:`{entity_name_label_source}`)-[r]-(end:`{entity_name_label_target}`)
410
  RETURN properties(r) as edge_properties
 
411
  """
412
 
413
  result = await session.run(query)
414
+ try:
415
 + records = await result.fetch(2)
 
416
 
417
+ if len(records) > 1:
418
+ logger.warning(
419
+ f"Multiple edges found between '{entity_name_label_source}' and '{entity_name_label_target}'. Using first edge."
 
 
 
 
 
420
  )
421
+ if records:
422
+ try:
423
+ edge_result = dict(records[0]["edge_properties"])
424
+ logger.debug(f"Result: {edge_result}")
425
+ # Ensure required keys exist with defaults
426
+ required_keys = {
427
+ "weight": 0.0,
428
+ "source_id": None,
429
+ "description": None,
430
+ "keywords": None,
431
+ }
432
+ for key, default_value in required_keys.items():
433
+ if key not in edge_result:
434
+ edge_result[key] = default_value
435
+ logger.warning(
436
+ f"Edge between {entity_name_label_source} and {entity_name_label_target} "
437
+ f"missing {key}, using default: {default_value}"
438
+ )
439
+
440
+ logger.debug(
441
+ f"{inspect.currentframe().f_code.co_name}:query:{query}:result:{edge_result}"
442
+ )
443
+ return edge_result
444
+ except (KeyError, TypeError, ValueError) as e:
445
+ logger.error(
446
+ f"Error processing edge properties between {entity_name_label_source} "
447
+ f"and {entity_name_label_target}: {str(e)}"
448
+ )
449
+ # Return default edge properties on error
450
+ return {
451
+ "weight": 0.0,
452
+ "source_id": None,
453
+ "description": None,
454
+ "keywords": None,
455
+ }
456
+
457
+ logger.debug(
458
+ f"{inspect.currentframe().f_code.co_name}: No edge found between {entity_name_label_source} and {entity_name_label_target}"
459
+ )
460
+ # Return default edge properties when no edge found
461
+ return {
462
+ "weight": 0.0,
463
+ "source_id": None,
464
+ "description": None,
465
+ "keywords": None,
466
+ }
467
+ finally:
468
+ await result.consume() # Ensure result is fully consumed
469
 
470
  except Exception as e:
471
  logger.error(
472
  f"Error in get_edge between {source_node_id} and {target_node_id}: {str(e)}"
473
  )
474
+ raise
 
 
 
 
 
 
475
 
476
  async def get_node_edges(self, source_node_id: str) -> list[tuple[str, str]] | None:
477
+ """Retrieves all edges (relationships) for a particular node identified by its label.
478
 
479
+ Args:
480
+ source_node_id: Label of the node to get edges for
481
+
482
+ Returns:
483
+ list[tuple[str, str]]: List of (source_label, target_label) tuples representing edges
484
+ None: If no edges found
485
+
486
+ Raises:
487
+ ValueError: If source_node_id is invalid
488
+ Exception: If there is an error executing the query
489
  """
490
+ try:
491
 + node_label = self._ensure_label(source_node_id)
 
492
 
493
+ query = f"""MATCH (n:`{node_label}`)
494
+ OPTIONAL MATCH (n)-[r]-(connected)
495
+ RETURN n, r, connected"""
496
 
497
+ async with self._driver.session(
498
+ database=self._DATABASE, default_access_mode="READ"
499
+ ) as session:
500
+ try:
501
+ results = await session.run(query)
502
+ edges = []
503
+
504
+ async for record in results:
505
+ source_node = record["n"]
506
+ connected_node = record["connected"]
507
+
508
+ source_label = (
509
+ list(source_node.labels)[0] if source_node.labels else None
510
+ )
511
+ target_label = (
512
+ list(connected_node.labels)[0]
513
+ if connected_node and connected_node.labels
514
+ else None
515
+ )
516
+
517
+ if source_label and target_label:
518
+ edges.append((source_label, target_label))
519
+
520
+ await results.consume() # Ensure results are consumed
521
+ return edges
522
+ except Exception as e:
523
+ logger.error(f"Error getting edges for node {node_label}: {str(e)}")
524
+ await results.consume() # Ensure results are consumed even on error
525
+ raise
526
+ except Exception as e:
527
+ logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
528
+ raise
529
 
530
  @retry(
531
  stop=stop_after_attempt(3),
 
547
  node_id: The unique identifier for the node (used as label)
548
  node_data: Dictionary of node properties
549
  """
550
+ label = self._ensure_label(node_id)
551
  properties = node_data
552
+ if "entity_id" not in properties:
553
+ raise ValueError("Neo4j: node properties must contain an 'entity_id' field")
 
 
 
 
 
 
 
 
554
 
555
  try:
556
  async with self._driver.session(database=self._DATABASE) as session:
557
+
558
+ async def execute_upsert(tx: AsyncManagedTransaction):
559
+ query = f"""
560
+ MERGE (n:`{label}` {{entity_id: $properties.entity_id}})
561
+ SET n += $properties
562
+ """
563
+ result = await tx.run(query, properties=properties)
564
+ logger.debug(
565
+ f"Upserted node with label '{label}' and properties: {properties}"
566
+ )
567
+ await result.consume() # Ensure result is fully consumed
568
+
569
+ await session.execute_write(execute_upsert)
570
  except Exception as e:
571
  logger.error(f"Error during upsert: {str(e)}")
572
  raise
573
 
574
+ @retry(
575
+ stop=stop_after_attempt(3),
576
+ wait=wait_exponential(multiplier=1, min=4, max=10),
577
+ retry=retry_if_exception_type(
578
+ (
579
+ neo4jExceptions.ServiceUnavailable,
580
+ neo4jExceptions.TransientError,
581
+ neo4jExceptions.WriteServiceUnavailable,
582
+ neo4jExceptions.ClientError,
583
+ )
584
+ ),
585
+ )
586
+ async def _get_unique_node_entity_id(self, node_label: str) -> str:
587
+ """
588
+ Get the entity_id of a node with the given label, ensuring the node is unique.
589
+
590
+ Args:
591
+ node_label (str): Label of the node to check
592
+
593
+ Returns:
594
+ str: The entity_id of the unique node
595
+
596
+ Raises:
597
+ ValueError: If no node with the given label exists or if multiple nodes have the same label
598
+ """
599
+ async with self._driver.session(
600
+ database=self._DATABASE, default_access_mode="READ"
601
+ ) as session:
602
+ query = f"""
603
+ MATCH (n:`{node_label}`)
604
+ RETURN n, count(n) as node_count
605
+ """
606
+ result = await session.run(query)
607
+ try:
608
+ records = await result.fetch(
609
+ 2
610
+ ) # We only need to know if there are 0, 1, or >1 nodes
611
+
612
+ if not records or records[0]["node_count"] == 0:
613
+ raise ValueError(
614
+ f"Neo4j: node with label '{node_label}' does not exist"
615
+ )
616
+
617
+ if records[0]["node_count"] > 1:
618
+ raise ValueError(
619
+ f"Neo4j: multiple nodes found with label '{node_label}', cannot determine unique node"
620
+ )
621
+
622
+ node = records[0]["n"]
623
+ if "entity_id" not in node:
624
+ raise ValueError(
625
+ f"Neo4j: node with label '{node_label}' does not have an entity_id property"
626
+ )
627
+
628
+ return node["entity_id"]
629
+ finally:
630
+ await result.consume() # Ensure result is fully consumed
631
+
632
  @retry(
633
  stop=stop_after_attempt(3),
634
  wait=wait_exponential(multiplier=1, min=4, max=10),
 
646
  ) -> None:
647
  """
648
  Upsert an edge and its properties between two nodes identified by their labels.
649
+ Ensures both source and target nodes exist and are unique before creating the edge.
650
+ Uses entity_id property to uniquely identify nodes.
651
 
652
  Args:
653
  source_node_id (str): Label of the source node (used as identifier)
654
  target_node_id (str): Label of the target node (used as identifier)
655
  edge_data (dict): Dictionary of properties to set on the edge
656
+
657
+ Raises:
658
+ ValueError: If either source or target node does not exist or is not unique
659
  """
660
+ source_label = self._ensure_label(source_node_id)
661
+ target_label = self._ensure_label(target_node_id)
662
  edge_properties = edge_data
663
 
664
+ # Get entity_ids for source and target nodes, ensuring they are unique
665
+ source_entity_id = await self._get_unique_node_entity_id(source_label)
666
 + target_entity_id = await self._get_unique_node_entity_id(target_label)
 
667
 
668
  try:
669
  async with self._driver.session(database=self._DATABASE) as session:
670
+
671
+ async def execute_upsert(tx: AsyncManagedTransaction):
672
+ query = f"""
673
+ MATCH (source:`{source_label}` {{entity_id: $source_entity_id}})
674
+ WITH source
675
+ MATCH (target:`{target_label}` {{entity_id: $target_entity_id}})
676
+ MERGE (source)-[r:DIRECTED]-(target)
677
+ SET r += $properties
678
+ RETURN r, source, target
679
+ """
680
+ result = await tx.run(
681
+ query,
682
+ source_entity_id=source_entity_id,
683
+ target_entity_id=target_entity_id,
684
+ properties=edge_properties,
685
+ )
686
+ try:
687
+ records = await result.fetch(100)
688
+ if records:
689
+ logger.debug(
690
+ f"Upserted edge from '{source_label}' (entity_id: {source_entity_id}) "
691
+ f"to '{target_label}' (entity_id: {target_entity_id}) "
692
+ f"with properties: {edge_properties}"
693
+ )
694
+ finally:
695
+ await result.consume() # Ensure result is consumed
696
+
697
+ await session.execute_write(execute_upsert)
698
  except Exception as e:
699
  logger.error(f"Error during edge upsert: {str(e)}")
700
  raise
 
703
  print("Implemented but never called.")
704
 
705
  async def get_knowledge_graph(
706
+ self,
707
+ node_label: str,
708
+ max_depth: int = 3,
709
+ min_degree: int = 0,
710
+ inclusive: bool = False,
711
  ) -> KnowledgeGraph:
712
  """
713
  Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
714
  Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
715
  When reducing the number of nodes, the prioritization criteria are as follows:
716
+ 1. min_degree does not affect nodes directly connected to the matching nodes
717
+ 2. Label matching nodes take precedence
718
+ 3. Followed by nodes directly connected to the matching nodes
719
+ 4. Finally, the degree of the nodes
720
 
721
  Args:
722
+ node_label: Label of the starting node
723
+ max_depth: Maximum depth of the subgraph
724
+ min_degree: Minimum degree of nodes to include. Defaults to 0
725
+ inclusive: Do an inclusive search if true
726
  Returns:
727
  KnowledgeGraph: Complete connected subgraph for specified node
728
  """
729
  label = node_label.strip('"')
 
 
730
  result = KnowledgeGraph()
731
  seen_nodes = set()
732
  seen_edges = set()
733
 
734
+ async with self._driver.session(
735
+ database=self._DATABASE, default_access_mode="READ"
736
+ ) as session:
737
  try:
738
  if label == "*":
739
  main_query = """
740
  MATCH (n)
741
  OPTIONAL MATCH (n)-[r]-()
742
  WITH n, count(r) AS degree
743
+ WHERE degree >= $min_degree
744
  ORDER BY degree DESC
745
  LIMIT $max_nodes
746
+ WITH collect({node: n}) AS filtered_nodes
747
+ UNWIND filtered_nodes AS node_info
748
+ WITH collect(node_info.node) AS kept_nodes, filtered_nodes
749
+ MATCH (a)-[r]-(b)
750
+ WHERE a IN kept_nodes AND b IN kept_nodes
751
+ RETURN filtered_nodes AS node_info,
752
+ collect(DISTINCT r) AS relationships
753
  """
754
  result_set = await session.run(
755
+ main_query,
756
+ {"max_nodes": MAX_GRAPH_NODES, "min_degree": min_degree},
757
  )
758
 
759
  else:
 
 
 
 
 
 
 
 
 
 
 
 
760
  # Main query uses partial matching
761
+ main_query = """
762
  MATCH (start)
763
+ WHERE any(label IN labels(start) WHERE
764
+ CASE
765
+ WHEN $inclusive THEN label CONTAINS $label
766
+ ELSE label = $label
767
+ END
768
+ )
769
  WITH start
770
+ CALL apoc.path.subgraphAll(start, {
771
+ relationshipFilter: '',
772
  minLevel: 0,
773
+ maxLevel: $max_depth,
774
  bfs: true
775
+ })
776
  YIELD nodes, relationships
777
  WITH start, nodes, relationships
778
  UNWIND nodes AS node
779
  OPTIONAL MATCH (node)-[r]-()
780
+ WITH node, count(r) AS degree, start, nodes, relationships
781
+ WHERE node = start OR EXISTS((start)--(node)) OR degree >= $min_degree
782
+ ORDER BY
783
+ CASE
784
+ WHEN node = start THEN 3
785
+ WHEN EXISTS((start)--(node)) THEN 2
786
+ ELSE 1
787
+ END DESC,
788
+ degree DESC
789
  LIMIT $max_nodes
790
+ WITH collect({node: node}) AS filtered_nodes
791
+ UNWIND filtered_nodes AS node_info
792
+ WITH collect(node_info.node) AS kept_nodes, filtered_nodes
793
+ MATCH (a)-[r]-(b)
794
+ WHERE a IN kept_nodes AND b IN kept_nodes
795
+ RETURN filtered_nodes AS node_info,
796
+ collect(DISTINCT r) AS relationships
797
  """
798
  result_set = await session.run(
799
+ main_query,
800
+ {
801
+ "max_nodes": MAX_GRAPH_NODES,
802
+ "label": label,
803
+ "inclusive": inclusive,
804
+ "max_depth": max_depth,
805
+ "min_degree": min_degree,
806
+ },
807
  )
808
 
809
+ try:
810
+ record = await result_set.single()
811
+
812
+ if record:
813
+ # Handle nodes (compatible with multi-label cases)
814
+ for node_info in record["node_info"]:
815
+ node = node_info["node"]
816
+ node_id = node.id
817
+ if node_id not in seen_nodes:
818
+ result.nodes.append(
819
+ KnowledgeGraphNode(
820
+ id=f"{node_id}",
821
+ labels=list(node.labels),
822
+ properties=dict(node),
823
+ )
824
  )
825
+ seen_nodes.add(node_id)
826
+
827
+ # Handle relationships (including direction information)
828
+ for rel in record["relationships"]:
829
+ edge_id = rel.id
830
+ if edge_id not in seen_edges:
831
+ start = rel.start_node
832
+ end = rel.end_node
833
+ result.edges.append(
834
+ KnowledgeGraphEdge(
835
+ id=f"{edge_id}",
836
+ type=rel.type,
837
+ source=f"{start.id}",
838
+ target=f"{end.id}",
839
+ properties=dict(rel),
840
+ )
841
  )
842
+ seen_edges.add(edge_id)
 
843
 
844
+ logger.info(
845
+ f"Process {os.getpid()} graph query return: {len(result.nodes)} nodes, {len(result.edges)} edges"
846
+ )
847
+ finally:
848
+ await result_set.consume() # Ensure result set is consumed
849
 
850
  except neo4jExceptions.ClientError as e:
851
+ logger.warning(f"APOC plugin error: {str(e)}")
852
+ if label != "*":
853
+ logger.warning(
854
+ "Neo4j: falling back to basic Cypher recursive search..."
855
+ )
856
+ if inclusive:
857
+ logger.warning(
858
+ "Neo4j: inclusive search mode is not supported in recursive query, using exact matching"
859
+ )
860
+ return await self._robust_fallback(label, max_depth, min_degree)
861
 
862
  return result
863
 
864
  async def _robust_fallback(
865
+ self, label: str, max_depth: int, min_degree: int = 0
866
+ ) -> KnowledgeGraph:
867
+ """
868
+ Fallback implementation when APOC plugin is not available or incompatible.
869
+ This method implements the same functionality as get_knowledge_graph but uses
870
+ only basic Cypher queries and recursive traversal instead of APOC procedures.
871
+ """
872
+ result = KnowledgeGraph()
873
  visited_nodes = set()
874
  visited_edges = set()
875
 
876
+ async def traverse(
877
+ node: KnowledgeGraphNode,
878
+ edge: Optional[KnowledgeGraphEdge],
879
+ current_depth: int,
880
+ ):
881
+ # Check traversal limits
882
  if current_depth > max_depth:
883
+ logger.debug(f"Reached max depth: {max_depth}")
884
  return
885
+ if len(visited_nodes) >= MAX_GRAPH_NODES:
886
+ logger.debug(f"Reached max nodes limit: {MAX_GRAPH_NODES}")
 
 
887
  return
888
 
889
+ # Check if node already visited
890
+ if node.id in visited_nodes:
891
  return
 
892
 
893
+ # Get all edges and target nodes
894
+ async with self._driver.session(
895
+ database=self._DATABASE, default_access_mode="READ"
896
+ ) as session:
897
+ query = """
898
+ MATCH (a)-[r]-(b)
899
+ WHERE id(a) = toInteger($node_id)
900
+ WITH r, b, id(r) as edge_id, id(b) as target_id
901
+ RETURN r, b, edge_id, target_id
902
+ """
903
+ results = await session.run(query, {"node_id": node.id})
904
+
905
+ # Get all records and release database connection
906
+ records = await results.fetch(
907
+ 1000
908
+ ) # Max neighbour nodes we can handle
909
+ await results.consume() # Ensure results are consumed
910
+
911
+ # Nodes not directly connected to the start node must satisfy the min_degree check
912
+ if current_depth > 1 and len(records) < min_degree:
913
+ return
914
+
915
+ # Add current node to result
916
+ result.nodes.append(node)
917
+ visited_nodes.add(node.id)
918
+
919
+ # Add edge to result if it exists and not already added
920
+ if edge and edge.id not in visited_edges:
921
+ result.edges.append(edge)
922
+ visited_edges.add(edge.id)
923
+
924
+ # Prepare nodes and edges for recursive processing
925
+ nodes_to_process = []
926
+ for record in records:
927
+ rel = record["r"]
928
+ edge_id = str(record["edge_id"])
929
+ if edge_id not in visited_edges:
930
+ b_node = record["b"]
931
+ target_id = str(record["target_id"])
932
+
933
+ if b_node.labels: # Only process if target node has labels
934
+ # Create KnowledgeGraphNode for target
935
+ target_node = KnowledgeGraphNode(
936
+ id=f"{target_id}",
937
+ labels=list(b_node.labels),
938
+ properties=dict(b_node),
939
+ )
940
+
941
+ # Create KnowledgeGraphEdge
942
+ target_edge = KnowledgeGraphEdge(
943
+ id=f"{edge_id}",
944
+ type=rel.type,
945
+ source=f"{node.id}",
946
+ target=f"{target_id}",
947
+ properties=dict(rel),
948
+ )
949
+
950
+ nodes_to_process.append((target_node, target_edge))
951
+ else:
952
+ logger.warning(
953
+ f"Skipping edge {edge_id} due to missing labels on target node"
954
+ )
955
+
956
+ # Process nodes after releasing database connection
957
+ for target_node, target_edge in nodes_to_process:
958
+ await traverse(target_node, target_edge, current_depth + 1)
959
 
960
+ # Get the starting node's data
961
+ async with self._driver.session(
962
+ database=self._DATABASE, default_access_mode="READ"
963
+ ) as session:
964
  query = f"""
965
+ MATCH (n:`{label}`)
966
+ RETURN id(n) as node_id, n
 
 
967
  """
968
+ node_result = await session.run(query)
969
+ try:
970
+ node_record = await node_result.single()
971
+ if not node_record:
972
+ return result
973
+
974
+ # Create initial KnowledgeGraphNode
975
+ start_node = KnowledgeGraphNode(
976
+ id=f"{node_record['node_id']}",
977
+ labels=list(node_record["n"].labels),
978
+ properties=dict(node_record["n"]),
979
+ )
980
+ finally:
981
+ await node_result.consume() # Ensure results are consumed
982
+
983
+ # Start traversal with the initial node
984
+ await traverse(start_node, None, 0)
 
 
 
 
 
 
 
 
 
985
 
 
986
  return result
987
 
988
  async def get_all_labels(self) -> list[str]:
 
991
  Returns:
992
  ["Person", "Company", ...] # Alphabetically sorted label list
993
  """
994
+ async with self._driver.session(
995
+ database=self._DATABASE, default_access_mode="READ"
996
+ ) as session:
997
  # Method 1: Direct metadata query (Available for Neo4j 4.3+)
998
  # query = "CALL db.labels() YIELD label RETURN label"
999
 
 
1005
  RETURN DISTINCT label
1006
  ORDER BY label
1007
  """
 
1008
  result = await session.run(query)
1009
  labels = []
1010
+ try:
1011
+ async for record in result:
1012
+ labels.append(record["label"])
1013
+ finally:
1014
+ await (
1015
+ result.consume()
1016
+ ) # Ensure results are consumed even if processing fails
1017
  return labels
1018
 
1019
  @retry(
 
1034
  Args:
1035
  node_id: The label of the node to delete
1036
  """
1037
+ label = self._ensure_label(node_id)
1038
 
1039
  async def _do_delete(tx: AsyncManagedTransaction):
1040
  query = f"""
1041
  MATCH (n:`{label}`)
1042
  DETACH DELETE n
1043
  """
1044
+ result = await tx.run(query)
1045
  logger.debug(f"Deleted node with label '{label}'")
1046
+ await result.consume() # Ensure result is fully consumed
1047
 
1048
  try:
1049
  async with self._driver.session(database=self._DATABASE) as session:
 
1092
  edges: List of edges to be deleted, each edge is a (source, target) tuple
1093
  """
1094
  for source, target in edges:
1095
+ source_label = self._ensure_label(source)
1096
+ target_label = self._ensure_label(target)
1097
 
1098
  async def _do_delete_edge(tx: AsyncManagedTransaction):
1099
  query = f"""
1100
+ MATCH (source:`{source_label}`)-[r]-(target:`{target_label}`)
1101
  DELETE r
1102
  """
1103
+ result = await tx.run(query)
1104
  logger.debug(f"Deleted edge from '{source_label}' to '{target_label}'")
1105
+ await result.consume() # Ensure result is fully consumed
1106
 
1107
  try:
1108
  async with self._driver.session(database=self._DATABASE) as session:
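
A minimal usage sketch of the extended get_knowledge_graph signature introduced above; the `storage` instance and the "Person" label are illustrative assumptions, only the keyword arguments come from this change.

import asyncio

async def demo(storage):
    # min_degree filters out low-degree nodes (nodes directly connected to the
    # match are exempt); inclusive=True switches label matching to CONTAINS.
    kg = await storage.get_knowledge_graph(
        node_label="Person",  # assumed label, for illustration only
        max_depth=3,
        min_degree=2,
        inclusive=True,
    )
    print(f"{len(kg.nodes)} nodes, {len(kg.edges)} edges")

# asyncio.run(demo(storage))  # requires an initialized Neo4j storage instance
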
lightrag/kg/shared_storage.py CHANGED
@@ -7,12 +7,18 @@ from typing import Any, Dict, Optional, Union, TypeVar, Generic
7
 
8
 
9
  # Define a direct print function for critical logs that must be visible in all processes
10
- def direct_log(message, level="INFO"):
11
  """
12
  Log a message directly to stderr to ensure visibility in all processes,
13
  including the Gunicorn master process.
 
 
 
 
 
14
  """
15
- print(f"{level}: {message}", file=sys.stderr, flush=True)
 
16
 
17
 
18
  T = TypeVar("T")
@@ -32,55 +38,165 @@ _update_flags: Optional[Dict[str, bool]] = None # namespace -> updated
32
  _storage_lock: Optional[LockType] = None
33
  _internal_lock: Optional[LockType] = None
34
  _pipeline_status_lock: Optional[LockType] = None
 
 
35
 
36
 
37
  class UnifiedLock(Generic[T]):
38
  """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""
39
 
40
- def __init__(self, lock: Union[ProcessLock, asyncio.Lock], is_async: bool):
 
 
 
 
 
 
41
  self._lock = lock
42
  self._is_async = is_async
 
 
 
43
 
44
  async def __aenter__(self) -> "UnifiedLock[T]":
45
- if self._is_async:
46
- await self._lock.acquire()
47
- else:
48
- self._lock.acquire()
49
- return self
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  async def __aexit__(self, exc_type, exc_val, exc_tb):
52
- if self._is_async:
53
- self._lock.release()
54
- else:
55
- self._lock.release()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  def __enter__(self) -> "UnifiedLock[T]":
58
  """For backward compatibility"""
59
- if self._is_async:
60
- raise RuntimeError("Use 'async with' for shared_storage lock")
61
- self._lock.acquire()
62
- return self
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  def __exit__(self, exc_type, exc_val, exc_tb):
65
  """For backward compatibility"""
66
- if self._is_async:
67
- raise RuntimeError("Use 'async with' for shared_storage lock")
68
- self._lock.release()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
 
71
- def get_internal_lock() -> UnifiedLock:
72
  """return unified storage lock for data consistency"""
73
- return UnifiedLock(lock=_internal_lock, is_async=not is_multiprocess)
 
 
 
 
 
74
 
75
 
76
- def get_storage_lock() -> UnifiedLock:
77
  """return unified storage lock for data consistency"""
78
- return UnifiedLock(lock=_storage_lock, is_async=not is_multiprocess)
 
 
 
 
 
79
 
80
 
81
- def get_pipeline_status_lock() -> UnifiedLock:
82
  """return unified storage lock for data consistency"""
83
- return UnifiedLock(lock=_pipeline_status_lock, is_async=not is_multiprocess)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
 
86
  def initialize_share_data(workers: int = 1):
@@ -108,6 +224,8 @@ def initialize_share_data(workers: int = 1):
108
  _storage_lock, \
109
  _internal_lock, \
110
  _pipeline_status_lock, \
 
 
111
  _shared_dicts, \
112
  _init_flags, \
113
  _initialized, \
@@ -120,14 +238,16 @@ def initialize_share_data(workers: int = 1):
120
  )
121
  return
122
 
123
- _manager = Manager()
124
  _workers = workers
125
 
126
  if workers > 1:
127
  is_multiprocess = True
 
128
  _internal_lock = _manager.Lock()
129
  _storage_lock = _manager.Lock()
130
  _pipeline_status_lock = _manager.Lock()
 
 
131
  _shared_dicts = _manager.dict()
132
  _init_flags = _manager.dict()
133
  _update_flags = _manager.dict()
@@ -139,6 +259,8 @@ def initialize_share_data(workers: int = 1):
139
  _internal_lock = asyncio.Lock()
140
  _storage_lock = asyncio.Lock()
141
  _pipeline_status_lock = asyncio.Lock()
 
 
142
  _shared_dicts = {}
143
  _init_flags = {}
144
  _update_flags = {}
@@ -164,6 +286,7 @@ async def initialize_pipeline_status():
164
  history_messages = _manager.list() if is_multiprocess else []
165
  pipeline_namespace.update(
166
  {
 
167
  "busy": False, # Control concurrent processes
168
  "job_name": "Default Job", # Current job name (indexing files/indexing texts)
169
  "job_start": None, # Job start time
@@ -200,7 +323,12 @@ async def get_update_flag(namespace: str):
200
  if is_multiprocess and _manager is not None:
201
  new_update_flag = _manager.Value("b", False)
202
  else:
203
- new_update_flag = False
 
 
 
 
 
204
 
205
  _update_flags[namespace].append(new_update_flag)
206
  return new_update_flag
@@ -220,7 +348,26 @@ async def set_all_update_flags(namespace: str):
220
  if is_multiprocess:
221
  _update_flags[namespace][i].value = True
222
  else:
223
- _update_flags[namespace][i] = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
 
225
 
226
  async def get_all_update_flags_status() -> Dict[str, list]:
@@ -247,7 +394,7 @@ async def get_all_update_flags_status() -> Dict[str, list]:
247
  return result
248
 
249
 
250
- def try_initialize_namespace(namespace: str) -> bool:
251
  """
252
  Returns True if the current worker (process) gets initialization permission for loading data later.
253
  A worker that does not get the permission is prohibited from loading data from files.
@@ -257,15 +404,17 @@ def try_initialize_namespace(namespace: str) -> bool:
257
  if _init_flags is None:
258
  raise ValueError("Try to create nanmespace before Shared-Data is initialized")
259
 
260
- if namespace not in _init_flags:
261
- _init_flags[namespace] = True
 
 
 
 
 
262
  direct_log(
263
- f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]"
264
  )
265
- return True
266
- direct_log(
267
- f"Process {os.getpid()} storage namespace already initialized: [{namespace}]"
268
- )
269
  return False
270
 
271
 
@@ -304,6 +453,8 @@ def finalize_share_data():
304
  _storage_lock, \
305
  _internal_lock, \
306
  _pipeline_status_lock, \
 
 
307
  _shared_dicts, \
308
  _init_flags, \
309
  _initialized, \
@@ -369,6 +520,8 @@ def finalize_share_data():
369
  _storage_lock = None
370
  _internal_lock = None
371
  _pipeline_status_lock = None
 
 
372
  _update_flags = None
373
 
374
  direct_log(f"Process {os.getpid()} storage data finalization complete")
 
7
 
8
 
9
  # Define a direct print function for critical logs that must be visible in all processes
10
+ def direct_log(message, level="INFO", enable_output: bool = True):
11
  """
12
  Log a message directly to stderr to ensure visibility in all processes,
13
  including the Gunicorn master process.
14
+
15
+ Args:
16
+ message: The message to log
17
+ level: Log level (default: "INFO")
18
+ enable_output: Whether to actually output the log (default: True)
19
  """
20
+ if enable_output:
21
+ print(f"{level}: {message}", file=sys.stderr, flush=True)
22
 
23
 
24
  T = TypeVar("T")
 
38
  _storage_lock: Optional[LockType] = None
39
  _internal_lock: Optional[LockType] = None
40
  _pipeline_status_lock: Optional[LockType] = None
41
+ _graph_db_lock: Optional[LockType] = None
42
+ _data_init_lock: Optional[LockType] = None
43
 
44
 
45
  class UnifiedLock(Generic[T]):
46
  """Provide a unified lock interface type for asyncio.Lock and multiprocessing.Lock"""
47
 
48
+ def __init__(
49
+ self,
50
+ lock: Union[ProcessLock, asyncio.Lock],
51
+ is_async: bool,
52
+ name: str = "unnamed",
53
+ enable_logging: bool = True,
54
+ ):
55
  self._lock = lock
56
  self._is_async = is_async
57
+ self._pid = os.getpid() # for debug only
58
+ self._name = name # for debug only
59
+ self._enable_logging = enable_logging # for debug only
60
 
61
  async def __aenter__(self) -> "UnifiedLock[T]":
62
+ try:
63
+ direct_log(
64
+ f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (async={self._is_async})",
65
+ enable_output=self._enable_logging,
66
+ )
67
+ if self._is_async:
68
+ await self._lock.acquire()
69
+ else:
70
+ self._lock.acquire()
71
+ direct_log(
72
+ f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (async={self._is_async})",
73
+ enable_output=self._enable_logging,
74
+ )
75
+ return self
76
+ except Exception as e:
77
+ direct_log(
78
+ f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}': {e}",
79
+ level="ERROR",
80
+ enable_output=self._enable_logging,
81
+ )
82
+ raise
83
 
84
  async def __aexit__(self, exc_type, exc_val, exc_tb):
85
+ try:
86
+ direct_log(
87
+ f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (async={self._is_async})",
88
+ enable_output=self._enable_logging,
89
+ )
90
+ if self._is_async:
91
+ self._lock.release()
92
+ else:
93
+ self._lock.release()
94
+ direct_log(
95
+ f"== Lock == Process {self._pid}: Lock '{self._name}' released (async={self._is_async})",
96
+ enable_output=self._enable_logging,
97
+ )
98
+ except Exception as e:
99
+ direct_log(
100
+ f"== Lock == Process {self._pid}: Failed to release lock '{self._name}': {e}",
101
+ level="ERROR",
102
+ enable_output=self._enable_logging,
103
+ )
104
+ raise
105
 
106
  def __enter__(self) -> "UnifiedLock[T]":
107
  """For backward compatibility"""
108
+ try:
109
+ if self._is_async:
110
+ raise RuntimeError("Use 'async with' for shared_storage lock")
111
+ direct_log(
112
+ f"== Lock == Process {self._pid}: Acquiring lock '{self._name}' (sync)",
113
+ enable_output=self._enable_logging,
114
+ )
115
+ self._lock.acquire()
116
+ direct_log(
117
+ f"== Lock == Process {self._pid}: Lock '{self._name}' acquired (sync)",
118
+ enable_output=self._enable_logging,
119
+ )
120
+ return self
121
+ except Exception as e:
122
+ direct_log(
123
+ f"== Lock == Process {self._pid}: Failed to acquire lock '{self._name}' (sync): {e}",
124
+ level="ERROR",
125
+ enable_output=self._enable_logging,
126
+ )
127
+ raise
128
 
129
  def __exit__(self, exc_type, exc_val, exc_tb):
130
  """For backward compatibility"""
131
+ try:
132
+ if self._is_async:
133
+ raise RuntimeError("Use 'async with' for shared_storage lock")
134
+ direct_log(
135
+ f"== Lock == Process {self._pid}: Releasing lock '{self._name}' (sync)",
136
+ enable_output=self._enable_logging,
137
+ )
138
+ self._lock.release()
139
+ direct_log(
140
+ f"== Lock == Process {self._pid}: Lock '{self._name}' released (sync)",
141
+ enable_output=self._enable_logging,
142
+ )
143
+ except Exception as e:
144
+ direct_log(
145
+ f"== Lock == Process {self._pid}: Failed to release lock '{self._name}' (sync): {e}",
146
+ level="ERROR",
147
+ enable_output=self._enable_logging,
148
+ )
149
+ raise
150
 
151
 
152
+ def get_internal_lock(enable_logging: bool = False) -> UnifiedLock:
153
  """return unified storage lock for data consistency"""
154
+ return UnifiedLock(
155
+ lock=_internal_lock,
156
+ is_async=not is_multiprocess,
157
+ name="internal_lock",
158
+ enable_logging=enable_logging,
159
+ )
160
 
161
 
162
+ def get_storage_lock(enable_logging: bool = False) -> UnifiedLock:
163
  """return unified storage lock for data consistency"""
164
+ return UnifiedLock(
165
+ lock=_storage_lock,
166
+ is_async=not is_multiprocess,
167
+ name="storage_lock",
168
+ enable_logging=enable_logging,
169
+ )
170
 
171
 
172
+ def get_pipeline_status_lock(enable_logging: bool = False) -> UnifiedLock:
173
  """return unified storage lock for data consistency"""
174
+ return UnifiedLock(
175
+ lock=_pipeline_status_lock,
176
+ is_async=not is_multiprocess,
177
+ name="pipeline_status_lock",
178
+ enable_logging=enable_logging,
179
+ )
180
+
181
+
182
+ def get_graph_db_lock(enable_logging: bool = False) -> UnifiedLock:
183
+ """return unified graph database lock for ensuring atomic operations"""
184
+ return UnifiedLock(
185
+ lock=_graph_db_lock,
186
+ is_async=not is_multiprocess,
187
+ name="graph_db_lock",
188
+ enable_logging=enable_logging,
189
+ )
190
+
191
+
192
+ def get_data_init_lock(enable_logging: bool = False) -> UnifiedLock:
193
+ """return unified data initialization lock for ensuring atomic data initialization"""
194
+ return UnifiedLock(
195
+ lock=_data_init_lock,
196
+ is_async=not is_multiprocess,
197
+ name="data_init_lock",
198
+ enable_logging=enable_logging,
199
+ )
200
 
201
 
202
  def initialize_share_data(workers: int = 1):
 
224
  _storage_lock, \
225
  _internal_lock, \
226
  _pipeline_status_lock, \
227
+ _graph_db_lock, \
228
+ _data_init_lock, \
229
  _shared_dicts, \
230
  _init_flags, \
231
  _initialized, \
 
238
  )
239
  return
240
 
 
241
  _workers = workers
242
 
243
  if workers > 1:
244
  is_multiprocess = True
245
+ _manager = Manager()
246
  _internal_lock = _manager.Lock()
247
  _storage_lock = _manager.Lock()
248
  _pipeline_status_lock = _manager.Lock()
249
+ _graph_db_lock = _manager.Lock()
250
+ _data_init_lock = _manager.Lock()
251
  _shared_dicts = _manager.dict()
252
  _init_flags = _manager.dict()
253
  _update_flags = _manager.dict()
 
259
  _internal_lock = asyncio.Lock()
260
  _storage_lock = asyncio.Lock()
261
  _pipeline_status_lock = asyncio.Lock()
262
+ _graph_db_lock = asyncio.Lock()
263
+ _data_init_lock = asyncio.Lock()
264
  _shared_dicts = {}
265
  _init_flags = {}
266
  _update_flags = {}
 
286
  history_messages = _manager.list() if is_multiprocess else []
287
  pipeline_namespace.update(
288
  {
289
+ "autoscanned": False, # Auto-scan started
290
  "busy": False, # Control concurrent processes
291
  "job_name": "Default Job", # Current job name (indexing files/indexing texts)
292
  "job_start": None, # Job start time
 
323
  if is_multiprocess and _manager is not None:
324
  new_update_flag = _manager.Value("b", False)
325
  else:
326
+ # Create a simple mutable object to store a boolean value, for compatibility with multiprocess mode
327
+ class MutableBoolean:
328
+ def __init__(self, initial_value=False):
329
+ self.value = initial_value
330
+
331
+ new_update_flag = MutableBoolean(False)
332
 
333
  _update_flags[namespace].append(new_update_flag)
334
  return new_update_flag
 
348
  if is_multiprocess:
349
  _update_flags[namespace][i].value = True
350
  else:
351
+ # Use .value attribute instead of direct assignment
352
+ _update_flags[namespace][i].value = True
353
+
354
+
355
+ async def clear_all_update_flags(namespace: str):
356
+ """Clear all update flag of namespace indicating all workers need to reload data from files"""
357
+ global _update_flags
358
+ if _update_flags is None:
359
+ raise ValueError("Try to create namespace before Shared-Data is initialized")
360
+
361
+ async with get_internal_lock():
362
+ if namespace not in _update_flags:
363
+ raise ValueError(f"Namespace {namespace} not found in update flags")
364
+ # Update flags for both modes
365
+ for i in range(len(_update_flags[namespace])):
366
+ if is_multiprocess:
367
+ _update_flags[namespace][i].value = False
368
+ else:
369
+ # Use .value attribute instead of direct assignment
370
+ _update_flags[namespace][i].value = False
371
 
372
 
373
  async def get_all_update_flags_status() -> Dict[str, list]:
 
394
  return result
395
 
396
 
397
+ async def try_initialize_namespace(namespace: str) -> bool:
398
  """
399
  Returns True if the current worker (process) gets initialization permission for loading data later.
400
  A worker that does not get the permission is prohibited from loading data from files.
 
404
  if _init_flags is None:
405
  raise ValueError("Try to create nanmespace before Shared-Data is initialized")
406
 
407
+ async with get_internal_lock():
408
+ if namespace not in _init_flags:
409
+ _init_flags[namespace] = True
410
+ direct_log(
411
+ f"Process {os.getpid()} ready to initialize storage namespace: [{namespace}]"
412
+ )
413
+ return True
414
  direct_log(
415
+ f"Process {os.getpid()} storage namespace already initialized: [{namespace}]"
416
  )
417
+
 
 
 
418
  return False
419
 
420
 
 
453
  _storage_lock, \
454
  _internal_lock, \
455
  _pipeline_status_lock, \
456
+ _graph_db_lock, \
457
+ _data_init_lock, \
458
  _shared_dicts, \
459
  _init_flags, \
460
  _initialized, \
 
520
  _storage_lock = None
521
  _internal_lock = None
522
  _pipeline_status_lock = None
523
+ _graph_db_lock = None
524
+ _data_init_lock = None
525
  _update_flags = None
526
 
527
  direct_log(f"Process {os.getpid()} storage data finalization complete")
lightrag/lightrag.py CHANGED
@@ -354,6 +354,9 @@ class LightRAG:
354
  namespace=make_namespace(
355
  self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
356
  ),
 
 
 
357
  embedding_func=self.embedding_func,
358
  )
359
 
@@ -404,18 +407,8 @@ class LightRAG:
404
  embedding_func=None,
405
  )
406
 
407
- if self.llm_response_cache and hasattr(
408
- self.llm_response_cache, "global_config"
409
- ):
410
- hashing_kv = self.llm_response_cache
411
- else:
412
- hashing_kv = self.key_string_value_json_storage_cls( # type: ignore
413
- namespace=make_namespace(
414
- self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
415
- ),
416
- global_config=asdict(self),
417
- embedding_func=self.embedding_func,
418
- )
419
 
420
  self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
421
  partial(
@@ -590,6 +583,7 @@ class LightRAG:
590
  split_by_character, split_by_character_only
591
  )
592
 
 
593
  def insert_custom_chunks(
594
  self,
595
  full_text: str,
@@ -601,6 +595,7 @@ class LightRAG:
601
  self.ainsert_custom_chunks(full_text, text_chunks, doc_id)
602
  )
603
 
 
604
  async def ainsert_custom_chunks(
605
  self, full_text: str, text_chunks: list[str], doc_id: str | None = None
606
  ) -> None:
@@ -892,7 +887,9 @@ class LightRAG:
892
  self.chunks_vdb.upsert(chunks)
893
  )
894
  entity_relation_task = asyncio.create_task(
895
- self._process_entity_relation_graph(chunks)
 
 
896
  )
897
  full_docs_task = asyncio.create_task(
898
  self.full_docs.upsert(
@@ -1007,21 +1004,27 @@ class LightRAG:
1007
  pipeline_status["latest_message"] = log_message
1008
  pipeline_status["history_messages"].append(log_message)
1009
 
1010
- async def _process_entity_relation_graph(self, chunk: dict[str, Any]) -> None:
 
 
1011
  try:
1012
  await extract_entities(
1013
  chunk,
1014
  knowledge_graph_inst=self.chunk_entity_relation_graph,
1015
  entity_vdb=self.entities_vdb,
1016
  relationships_vdb=self.relationships_vdb,
1017
- llm_response_cache=self.llm_response_cache,
1018
  global_config=asdict(self),
 
 
 
1019
  )
1020
  except Exception as e:
1021
  logger.error("Failed to extract entities and relationships")
1022
  raise e
1023
 
1024
- async def _insert_done(self) -> None:
 
 
1025
  tasks = [
1026
  cast(StorageNameSpace, storage_inst).index_done_callback()
1027
  for storage_inst in [ # type: ignore
@@ -1040,12 +1043,10 @@ class LightRAG:
1040
  log_message = "All Insert done"
1041
  logger.info(log_message)
1042
 
1043
- # 获取 pipeline_status 并更新 latest_message history_messages
1044
- from lightrag.kg.shared_storage import get_namespace_data
1045
-
1046
- pipeline_status = await get_namespace_data("pipeline_status")
1047
- pipeline_status["latest_message"] = log_message
1048
- pipeline_status["history_messages"].append(log_message)
1049
 
1050
  def insert_custom_kg(
1051
  self, custom_kg: dict[str, Any], full_doc_id: str = None
@@ -1260,16 +1261,7 @@ class LightRAG:
1260
  self.text_chunks,
1261
  param,
1262
  asdict(self),
1263
- hashing_kv=self.llm_response_cache
1264
- if self.llm_response_cache
1265
- and hasattr(self.llm_response_cache, "global_config")
1266
- else self.key_string_value_json_storage_cls(
1267
- namespace=make_namespace(
1268
- self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
1269
- ),
1270
- global_config=asdict(self),
1271
- embedding_func=self.embedding_func,
1272
- ),
1273
  system_prompt=system_prompt,
1274
  )
1275
  elif param.mode == "naive":
@@ -1279,16 +1271,7 @@ class LightRAG:
1279
  self.text_chunks,
1280
  param,
1281
  asdict(self),
1282
- hashing_kv=self.llm_response_cache
1283
- if self.llm_response_cache
1284
- and hasattr(self.llm_response_cache, "global_config")
1285
- else self.key_string_value_json_storage_cls(
1286
- namespace=make_namespace(
1287
- self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
1288
- ),
1289
- global_config=asdict(self),
1290
- embedding_func=self.embedding_func,
1291
- ),
1292
  system_prompt=system_prompt,
1293
  )
1294
  elif param.mode == "mix":
@@ -1301,16 +1284,7 @@ class LightRAG:
1301
  self.text_chunks,
1302
  param,
1303
  asdict(self),
1304
- hashing_kv=self.llm_response_cache
1305
- if self.llm_response_cache
1306
- and hasattr(self.llm_response_cache, "global_config")
1307
- else self.key_string_value_json_storage_cls(
1308
- namespace=make_namespace(
1309
- self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
1310
- ),
1311
- global_config=asdict(self),
1312
- embedding_func=self.embedding_func,
1313
- ),
1314
  system_prompt=system_prompt,
1315
  )
1316
  else:
@@ -1344,14 +1318,7 @@ class LightRAG:
1344
  text=query,
1345
  param=param,
1346
  global_config=asdict(self),
1347
- hashing_kv=self.llm_response_cache
1348
- or self.key_string_value_json_storage_cls(
1349
- namespace=make_namespace(
1350
- self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
1351
- ),
1352
- global_config=asdict(self),
1353
- embedding_func=self.embedding_func,
1354
- ),
1355
  )
1356
 
1357
  param.hl_keywords = hl_keywords
@@ -1375,16 +1342,7 @@ class LightRAG:
1375
  self.text_chunks,
1376
  param,
1377
  asdict(self),
1378
- hashing_kv=self.llm_response_cache
1379
- if self.llm_response_cache
1380
- and hasattr(self.llm_response_cache, "global_config")
1381
- else self.key_string_value_json_storage_cls(
1382
- namespace=make_namespace(
1383
- self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
1384
- ),
1385
- global_config=asdict(self),
1386
- embedding_func=self.embedding_func,
1387
- ),
1388
  )
1389
  elif param.mode == "naive":
1390
  response = await naive_query(
@@ -1393,16 +1351,7 @@ class LightRAG:
1393
  self.text_chunks,
1394
  param,
1395
  asdict(self),
1396
- hashing_kv=self.llm_response_cache
1397
- if self.llm_response_cache
1398
- and hasattr(self.llm_response_cache, "global_config")
1399
- else self.key_string_value_json_storage_cls(
1400
- namespace=make_namespace(
1401
- self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
1402
- ),
1403
- global_config=asdict(self),
1404
- embedding_func=self.embedding_func,
1405
- ),
1406
  )
1407
  elif param.mode == "mix":
1408
  response = await mix_kg_vector_query(
@@ -1414,16 +1363,7 @@ class LightRAG:
1414
  self.text_chunks,
1415
  param,
1416
  asdict(self),
1417
- hashing_kv=self.llm_response_cache
1418
- if self.llm_response_cache
1419
- and hasattr(self.llm_response_cache, "global_config")
1420
- else self.key_string_value_json_storage_cls(
1421
- namespace=make_namespace(
1422
- self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
1423
- ),
1424
- global_config=asdict(self),
1425
- embedding_func=self.embedding_func,
1426
- ),
1427
  )
1428
  else:
1429
  raise ValueError(f"Unknown mode {param.mode}")
 
354
  namespace=make_namespace(
355
  self.namespace_prefix, NameSpace.KV_STORE_LLM_RESPONSE_CACHE
356
  ),
357
+ global_config=asdict(
358
+ self
359
+ ), # Add global_config to ensure cache works properly
360
  embedding_func=self.embedding_func,
361
  )
362
 
 
407
  embedding_func=None,
408
  )
409
 
410
+ # Directly use llm_response_cache, don't create a new object
411
+ hashing_kv = self.llm_response_cache
 
 
 
 
 
 
 
 
 
 
412
 
413
  self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
414
  partial(
 
583
  split_by_character, split_by_character_only
584
  )
585
 
586
+ # TODO: deprecated, use insert instead
587
  def insert_custom_chunks(
588
  self,
589
  full_text: str,
 
595
  self.ainsert_custom_chunks(full_text, text_chunks, doc_id)
596
  )
597
 
598
+ # TODO: deprecated, use ainsert instead
599
  async def ainsert_custom_chunks(
600
  self, full_text: str, text_chunks: list[str], doc_id: str | None = None
601
  ) -> None:
 
887
  self.chunks_vdb.upsert(chunks)
888
  )
889
  entity_relation_task = asyncio.create_task(
890
+ self._process_entity_relation_graph(
891
+ chunks, pipeline_status, pipeline_status_lock
892
+ )
893
  )
894
  full_docs_task = asyncio.create_task(
895
  self.full_docs.upsert(
 
1004
  pipeline_status["latest_message"] = log_message
1005
  pipeline_status["history_messages"].append(log_message)
1006
 
1007
+ async def _process_entity_relation_graph(
1008
+ self, chunk: dict[str, Any], pipeline_status=None, pipeline_status_lock=None
1009
+ ) -> None:
1010
  try:
1011
  await extract_entities(
1012
  chunk,
1013
  knowledge_graph_inst=self.chunk_entity_relation_graph,
1014
  entity_vdb=self.entities_vdb,
1015
  relationships_vdb=self.relationships_vdb,
 
1016
  global_config=asdict(self),
1017
+ pipeline_status=pipeline_status,
1018
+ pipeline_status_lock=pipeline_status_lock,
1019
+ llm_response_cache=self.llm_response_cache,
1020
  )
1021
  except Exception as e:
1022
  logger.error("Failed to extract entities and relationships")
1023
  raise e
1024
 
1025
+ async def _insert_done(
1026
+ self, pipeline_status=None, pipeline_status_lock=None
1027
+ ) -> None:
1028
  tasks = [
1029
  cast(StorageNameSpace, storage_inst).index_done_callback()
1030
  for storage_inst in [ # type: ignore
 
1043
  log_message = "All Insert done"
1044
  logger.info(log_message)
1045
 
1046
+ if pipeline_status is not None and pipeline_status_lock is not None:
1047
+ async with pipeline_status_lock:
1048
+ pipeline_status["latest_message"] = log_message
1049
+ pipeline_status["history_messages"].append(log_message)
 
 
1050
 
1051
  def insert_custom_kg(
1052
  self, custom_kg: dict[str, Any], full_doc_id: str = None
 
1261
  self.text_chunks,
1262
  param,
1263
  asdict(self),
1264
+ hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
 
 
 
 
 
 
 
 
 
1265
  system_prompt=system_prompt,
1266
  )
1267
  elif param.mode == "naive":
 
1271
  self.text_chunks,
1272
  param,
1273
  asdict(self),
1274
+ hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
 
 
 
 
 
 
 
 
 
1275
  system_prompt=system_prompt,
1276
  )
1277
  elif param.mode == "mix":
 
1284
  self.text_chunks,
1285
  param,
1286
  asdict(self),
1287
+ hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
 
 
 
 
 
 
 
 
 
1288
  system_prompt=system_prompt,
1289
  )
1290
  else:
 
1318
  text=query,
1319
  param=param,
1320
  global_config=asdict(self),
1321
+ hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
 
 
 
 
 
 
 
1322
  )
1323
 
1324
  param.hl_keywords = hl_keywords
 
1342
  self.text_chunks,
1343
  param,
1344
  asdict(self),
1345
+ hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
 
 
 
 
 
 
 
 
 
1346
  )
1347
  elif param.mode == "naive":
1348
  response = await naive_query(
 
1351
  self.text_chunks,
1352
  param,
1353
  asdict(self),
1354
+ hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
 
 
 
 
 
 
 
 
 
1355
  )
1356
  elif param.mode == "mix":
1357
  response = await mix_kg_vector_query(
 
1363
  self.text_chunks,
1364
  param,
1365
  asdict(self),
1366
+ hashing_kv=self.llm_response_cache, # Directly use llm_response_cache
 
 
 
 
 
 
 
 
 
1367
  )
1368
  else:
1369
  raise ValueError(f"Unknown mode {param.mode}")
lightrag/operate.py CHANGED
@@ -3,6 +3,7 @@ from __future__ import annotations
3
  import asyncio
4
  import json
5
  import re
 
6
  from typing import Any, AsyncIterator
7
  from collections import Counter, defaultdict
8
 
@@ -220,6 +221,7 @@ async def _merge_nodes_then_upsert(
220
  entity_name, description, global_config
221
  )
222
  node_data = dict(
 
223
  entity_type=entity_type,
224
  description=description,
225
  source_id=source_id,
@@ -301,6 +303,7 @@ async def _merge_edges_then_upsert(
301
  await knowledge_graph_inst.upsert_node(
302
  need_insert_id,
303
  node_data={
 
304
  "source_id": source_id,
305
  "description": description,
306
  "entity_type": "UNKNOWN",
@@ -337,11 +340,10 @@ async def extract_entities(
337
  entity_vdb: BaseVectorStorage,
338
  relationships_vdb: BaseVectorStorage,
339
  global_config: dict[str, str],
 
 
340
  llm_response_cache: BaseKVStorage | None = None,
341
  ) -> None:
342
- from lightrag.kg.shared_storage import get_namespace_data
343
-
344
- pipeline_status = await get_namespace_data("pipeline_status")
345
  use_llm_func: callable = global_config["llm_model_func"]
346
  entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
347
  enable_llm_cache_for_entity_extract: bool = global_config[
@@ -400,6 +402,7 @@ async def extract_entities(
400
  else:
401
  _prompt = input_text
402
 
 
403
  arg_hash = compute_args_hash(_prompt)
404
  cached_return, _1, _2, _3 = await handle_cache(
405
  llm_response_cache,
@@ -407,7 +410,6 @@ async def extract_entities(
407
  _prompt,
408
  "default",
409
  cache_type="extract",
410
- force_llm_cache=True,
411
  )
412
  if cached_return:
413
  logger.debug(f"Found cache for {arg_hash}")
@@ -504,8 +506,10 @@ async def extract_entities(
504
  relations_count = len(maybe_edges)
505
  log_message = f" Chunk {processed_chunks}/{total_chunks}: extracted {entities_count} entities and {relations_count} relationships (deduplicated)"
506
  logger.info(log_message)
507
- pipeline_status["latest_message"] = log_message
508
- pipeline_status["history_messages"].append(log_message)
 
 
509
  return dict(maybe_nodes), dict(maybe_edges)
510
 
511
  tasks = [_process_single_content(c) for c in ordered_chunks]
@@ -519,42 +523,58 @@ async def extract_entities(
519
  for k, v in m_edges.items():
520
  maybe_edges[tuple(sorted(k))].extend(v)
521
 
522
- all_entities_data = await asyncio.gather(
523
- *[
524
- _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
525
- for k, v in maybe_nodes.items()
526
- ]
527
- )
528
 
529
- all_relationships_data = await asyncio.gather(
530
- *[
531
- _merge_edges_then_upsert(k[0], k[1], v, knowledge_graph_inst, global_config)
532
- for k, v in maybe_edges.items()
533
- ]
534
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
535
 
536
  if not (all_entities_data or all_relationships_data):
537
  log_message = "Didn't extract any entities and relationships."
538
  logger.info(log_message)
539
- pipeline_status["latest_message"] = log_message
540
- pipeline_status["history_messages"].append(log_message)
 
 
541
  return
542
 
543
  if not all_entities_data:
544
  log_message = "Didn't extract any entities"
545
  logger.info(log_message)
546
- pipeline_status["latest_message"] = log_message
547
- pipeline_status["history_messages"].append(log_message)
 
 
548
  if not all_relationships_data:
549
  log_message = "Didn't extract any relationships"
550
  logger.info(log_message)
551
- pipeline_status["latest_message"] = log_message
552
- pipeline_status["history_messages"].append(log_message)
 
 
553
 
554
  log_message = f"Extracted {len(all_entities_data)} entities and {len(all_relationships_data)} relationships (deduplicated)"
555
  logger.info(log_message)
556
- pipeline_status["latest_message"] = log_message
557
- pipeline_status["history_messages"].append(log_message)
 
 
558
  verbose_debug(
559
  f"New entities:{all_entities_data}, relationships:{all_relationships_data}"
560
  )
@@ -1017,6 +1037,7 @@ async def _build_query_context(
1017
  text_chunks_db: BaseKVStorage,
1018
  query_param: QueryParam,
1019
  ):
 
1020
  if query_param.mode == "local":
1021
  entities_context, relations_context, text_units_context = await _get_node_data(
1022
  ll_keywords,
 
3
  import asyncio
4
  import json
5
  import re
6
+ import os
7
  from typing import Any, AsyncIterator
8
  from collections import Counter, defaultdict
9
 
 
221
  entity_name, description, global_config
222
  )
223
  node_data = dict(
224
+ entity_id=entity_name,
225
  entity_type=entity_type,
226
  description=description,
227
  source_id=source_id,
 
303
  await knowledge_graph_inst.upsert_node(
304
  need_insert_id,
305
  node_data={
306
+ "entity_id": need_insert_id,
307
  "source_id": source_id,
308
  "description": description,
309
  "entity_type": "UNKNOWN",
 
340
  entity_vdb: BaseVectorStorage,
341
  relationships_vdb: BaseVectorStorage,
342
  global_config: dict[str, str],
343
+ pipeline_status: dict = None,
344
+ pipeline_status_lock=None,
345
  llm_response_cache: BaseKVStorage | None = None,
346
  ) -> None:
 
 
 
347
  use_llm_func: callable = global_config["llm_model_func"]
348
  entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
349
  enable_llm_cache_for_entity_extract: bool = global_config[
 
402
  else:
403
  _prompt = input_text
404
 
405
+ # TODO: add cache_type="extract"
406
  arg_hash = compute_args_hash(_prompt)
407
  cached_return, _1, _2, _3 = await handle_cache(
408
  llm_response_cache,
 
410
  _prompt,
411
  "default",
412
  cache_type="extract",
 
413
  )
414
  if cached_return:
415
  logger.debug(f"Found cache for {arg_hash}")
 
506
  relations_count = len(maybe_edges)
507
  log_message = f" Chunk {processed_chunks}/{total_chunks}: extracted {entities_count} entities and {relations_count} relationships (deduplicated)"
508
  logger.info(log_message)
509
+ if pipeline_status is not None:
510
+ async with pipeline_status_lock:
511
+ pipeline_status["latest_message"] = log_message
512
+ pipeline_status["history_messages"].append(log_message)
513
  return dict(maybe_nodes), dict(maybe_edges)
514
 
515
  tasks = [_process_single_content(c) for c in ordered_chunks]
 
523
  for k, v in m_edges.items():
524
  maybe_edges[tuple(sorted(k))].extend(v)
525
 
526
+ from .kg.shared_storage import get_graph_db_lock
 
 
 
 
 
527
 
528
+ graph_db_lock = get_graph_db_lock(enable_logging=False)
529
+
530
+ # Ensure that nodes and edges are merged and upserted atomically
531
+ async with graph_db_lock:
532
+ all_entities_data = await asyncio.gather(
533
+ *[
534
+ _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
535
+ for k, v in maybe_nodes.items()
536
+ ]
537
+ )
538
+
539
+ all_relationships_data = await asyncio.gather(
540
+ *[
541
+ _merge_edges_then_upsert(
542
+ k[0], k[1], v, knowledge_graph_inst, global_config
543
+ )
544
+ for k, v in maybe_edges.items()
545
+ ]
546
+ )
547
 
548
  if not (all_entities_data or all_relationships_data):
549
  log_message = "Didn't extract any entities and relationships."
550
  logger.info(log_message)
551
+ if pipeline_status is not None:
552
+ async with pipeline_status_lock:
553
+ pipeline_status["latest_message"] = log_message
554
+ pipeline_status["history_messages"].append(log_message)
555
  return
556
 
557
  if not all_entities_data:
558
  log_message = "Didn't extract any entities"
559
  logger.info(log_message)
560
+ if pipeline_status is not None:
561
+ async with pipeline_status_lock:
562
+ pipeline_status["latest_message"] = log_message
563
+ pipeline_status["history_messages"].append(log_message)
564
  if not all_relationships_data:
565
  log_message = "Didn't extract any relationships"
566
  logger.info(log_message)
567
+ if pipeline_status is not None:
568
+ async with pipeline_status_lock:
569
+ pipeline_status["latest_message"] = log_message
570
+ pipeline_status["history_messages"].append(log_message)
571
 
572
  log_message = f"Extracted {len(all_entities_data)} entities and {len(all_relationships_data)} relationships (deduplicated)"
573
  logger.info(log_message)
574
+ if pipeline_status is not None:
575
+ async with pipeline_status_lock:
576
+ pipeline_status["latest_message"] = log_message
577
+ pipeline_status["history_messages"].append(log_message)
578
  verbose_debug(
579
  f"New entities:{all_entities_data}, relationships:{all_relationships_data}"
580
  )
 
1037
  text_chunks_db: BaseKVStorage,
1038
  query_param: QueryParam,
1039
  ):
1040
+ logger.info(f"Process {os.getpid()} buidling query context...")
1041
  if query_param.mode == "local":
1042
  entities_context, relations_context, text_units_context = await _get_node_data(
1043
  ll_keywords,
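
For context, a sketch of how the updated extract_entities signature is driven from LightRAG._process_entity_relation_graph; the `rag` instance is assumed to be an already initialized LightRAG object and is not constructed here.

from dataclasses import asdict

from lightrag.operate import extract_entities
from lightrag.kg.shared_storage import get_namespace_data, get_pipeline_status_lock

async def run_extraction(rag, chunks: dict) -> None:
    pipeline_status = await get_namespace_data("pipeline_status")
    pipeline_status_lock = get_pipeline_status_lock()
    await extract_entities(
        chunks,
        knowledge_graph_inst=rag.chunk_entity_relation_graph,
        entity_vdb=rag.entities_vdb,
        relationships_vdb=rag.relationships_vdb,
        global_config=asdict(rag),
        pipeline_status=pipeline_status,
        pipeline_status_lock=pipeline_status_lock,
        llm_response_cache=rag.llm_response_cache,
    )
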
lightrag/utils.py CHANGED
@@ -633,15 +633,15 @@ async def handle_cache(
633
  prompt,
634
  mode="default",
635
  cache_type=None,
636
- force_llm_cache=False,
637
  ):
638
  """Generic cache handling function"""
639
- if hashing_kv is None or not (
640
- force_llm_cache or hashing_kv.global_config.get("enable_llm_cache")
641
- ):
642
  return None, None, None, None
643
 
644
- if mode != "default":
 
 
 
645
  # Get embedding cache configuration
646
  embedding_cache_config = hashing_kv.global_config.get(
647
  "embedding_cache_config",
@@ -651,8 +651,7 @@ async def handle_cache(
651
  use_llm_check = embedding_cache_config.get("use_llm_check", False)
652
 
653
  quantized = min_val = max_val = None
654
- if is_embedding_cache_enabled:
655
- # Use embedding cache
656
  current_embedding = await hashing_kv.embedding_func([prompt])
657
  llm_model_func = hashing_kv.global_config.get("llm_model_func")
658
  quantized, min_val, max_val = quantize_embedding(current_embedding[0])
@@ -667,24 +666,29 @@ async def handle_cache(
667
  cache_type=cache_type,
668
  )
669
  if best_cached_response is not None:
670
- logger.info(f"Embedding cached hit(mode:{mode} type:{cache_type})")
671
  return best_cached_response, None, None, None
672
  else:
673
  # if caching keyword embedding is enabled, return the quantized embedding for saving it later
674
- logger.info(f"Embedding cached missed(mode:{mode} type:{cache_type})")
675
  return None, quantized, min_val, max_val
676
 
677
- # For default mode or is_embedding_cache_enabled is False, use regular cache
678
- # default mode is for extract_entities or naive query
 
 
 
 
 
679
  if exists_func(hashing_kv, "get_by_mode_and_id"):
680
  mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
681
  else:
682
  mode_cache = await hashing_kv.get_by_id(mode) or {}
683
  if args_hash in mode_cache:
684
- logger.info(f"Non-embedding cached hit(mode:{mode} type:{cache_type})")
685
  return mode_cache[args_hash]["return"], None, None, None
686
 
687
- logger.info(f"Non-embedding cached missed(mode:{mode} type:{cache_type})")
688
  return None, None, None, None
689
 
690
 
@@ -701,9 +705,22 @@ class CacheData:
701
 
702
 
703
  async def save_to_cache(hashing_kv, cache_data: CacheData):
704
- if hashing_kv is None or hasattr(cache_data.content, "__aiter__"):
 
 
 
 
 
 
 
705
  return
706
 
 
 
 
 
 
 
707
  if exists_func(hashing_kv, "get_by_mode_and_id"):
708
  mode_cache = (
709
  await hashing_kv.get_by_mode_and_id(cache_data.mode, cache_data.args_hash)
@@ -712,6 +729,16 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
712
  else:
713
  mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
714
 
 
 
 
 
 
 
 
 
 
 
715
  mode_cache[cache_data.args_hash] = {
716
  "return": cache_data.content,
717
  "cache_type": cache_data.cache_type,
@@ -726,6 +753,7 @@ async def save_to_cache(hashing_kv, cache_data: CacheData):
726
  "original_prompt": cache_data.prompt,
727
  }
728
 
 
729
  await hashing_kv.upsert({cache_data.mode: mode_cache})
730
 
731
 
 
633
  prompt,
634
  mode="default",
635
  cache_type=None,
 
636
  ):
637
  """Generic cache handling function"""
638
+ if hashing_kv is None:
 
 
639
  return None, None, None, None
640
 
641
+ if mode != "default": # handle cache for all type of query
642
+ if not hashing_kv.global_config.get("enable_llm_cache"):
643
+ return None, None, None, None
644
+
645
  # Get embedding cache configuration
646
  embedding_cache_config = hashing_kv.global_config.get(
647
  "embedding_cache_config",
 
651
  use_llm_check = embedding_cache_config.get("use_llm_check", False)
652
 
653
  quantized = min_val = max_val = None
654
+ if is_embedding_cache_enabled: # Use embedding similarity to match cache
 
655
  current_embedding = await hashing_kv.embedding_func([prompt])
656
  llm_model_func = hashing_kv.global_config.get("llm_model_func")
657
  quantized, min_val, max_val = quantize_embedding(current_embedding[0])
 
666
  cache_type=cache_type,
667
  )
668
  if best_cached_response is not None:
669
+ logger.debug(f"Embedding cached hit(mode:{mode} type:{cache_type})")
670
  return best_cached_response, None, None, None
671
  else:
672
  # if caching keyword embedding is enabled, return the quantized embedding for saving it later
673
+ logger.debug(f"Embedding cached missed(mode:{mode} type:{cache_type})")
674
  return None, quantized, min_val, max_val
675
 
676
+ else: # handle cache for entity extraction
677
+ if not hashing_kv.global_config.get("enable_llm_cache_for_entity_extract"):
678
+ return None, None, None, None
679
+
680
+ # These are the conditions under which code reaches this point:
681
+ # 1. All query modes: enable_llm_cache is True and embedding similarity matching is not enabled
682
+ # 2. Entity extraction: enable_llm_cache_for_entity_extract is True
683
  if exists_func(hashing_kv, "get_by_mode_and_id"):
684
  mode_cache = await hashing_kv.get_by_mode_and_id(mode, args_hash) or {}
685
  else:
686
  mode_cache = await hashing_kv.get_by_id(mode) or {}
687
  if args_hash in mode_cache:
688
+ logger.debug(f"Non-embedding cached hit(mode:{mode} type:{cache_type})")
689
  return mode_cache[args_hash]["return"], None, None, None
690
 
691
+ logger.debug(f"Non-embedding cached missed(mode:{mode} type:{cache_type})")
692
  return None, None, None, None
693
 
694
 
 
705
 
706
 
707
  async def save_to_cache(hashing_kv, cache_data: CacheData):
708
+ """Save data to cache, with improved handling for streaming responses and duplicate content.
709
+
710
+ Args:
711
+ hashing_kv: The key-value storage for caching
712
+ cache_data: The cache data to save
713
+ """
714
+ # Skip if storage is None or content is empty
715
+ if hashing_kv is None or not cache_data.content:
716
  return
717
 
718
+ # If content is a streaming response, don't cache it
719
+ if hasattr(cache_data.content, "__aiter__"):
720
+ logger.debug("Streaming response detected, skipping cache")
721
+ return
722
+
723
+ # Get existing cache data
724
  if exists_func(hashing_kv, "get_by_mode_and_id"):
725
  mode_cache = (
726
  await hashing_kv.get_by_mode_and_id(cache_data.mode, cache_data.args_hash)
 
729
  else:
730
  mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
731
 
732
+ # Check if we already have identical content cached
733
+ if cache_data.args_hash in mode_cache:
734
+ existing_content = mode_cache[cache_data.args_hash].get("return")
735
+ if existing_content == cache_data.content:
736
+ logger.info(
737
+ f"Cache content unchanged for {cache_data.args_hash}, skipping update"
738
+ )
739
+ return
740
+
741
+ # Update cache with new content
742
  mode_cache[cache_data.args_hash] = {
743
  "return": cache_data.content,
744
  "cache_type": cache_data.cache_type,
 
753
  "original_prompt": cache_data.prompt,
754
  }
755
 
756
+ # Only upsert if there's actual new content
757
  await hashing_kv.upsert({cache_data.mode: mode_cache})
758
 
759
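
Finally, a sketch of the cache lookup flow after the handle_cache change above. The `kv` storage, the inline hashing (a stand-in for the project's own compute_args_hash), and the cache_type values are assumptions for illustration; the branching comments restate the conditions documented in the diff.

import hashlib

from lightrag.utils import handle_cache

async def lookup_cached_response(kv, prompt: str, mode: str = "default"):
    # Stand-in for compute_args_hash; the real helper may hash differently.
    args_hash = hashlib.md5(prompt.encode("utf-8")).hexdigest()

    # mode != "default": query caching, gated by enable_llm_cache and optionally
    #                    routed through embedding-similarity matching.
    # mode == "default": entity-extraction caching, gated by
    #                    enable_llm_cache_for_entity_extract.
    cached, quantized, min_val, max_val = await handle_cache(
        kv,
        args_hash,
        prompt,
        mode=mode,
        cache_type="extract" if mode == "default" else None,
    )
    return cached
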