yangdx commited on
Commit
09ffe28
·
1 Parent(s): 8050d31

Optimize NetworkX subgraph query

Browse files
Files changed (1) hide show
  1. lightrag/kg/networkx_impl.py +38 -97
lightrag/kg/networkx_impl.py CHANGED
@@ -259,118 +259,59 @@ class NetworkXStorage(BaseGraphStorage):
259
  self,
260
  node_label: str,
261
  max_depth: int = 3,
262
- min_degree: int = 0,
263
- inclusive: bool = False,
264
  ) -> KnowledgeGraph:
265
  """
266
  Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
267
- Maximum number of nodes is constrained by the environment variable `MAX_GRAPH_NODES` (default: 1000).
268
- When reducing the number of nodes, the prioritization criteria are as follows:
269
- 1. min_degree does not affect nodes directly connected to the matching nodes
270
- 2. Label matching nodes take precedence
271
- 3. Followed by nodes directly connected to the matching nodes
272
- 4. Finally, the degree of the nodes
273
 
274
  Args:
275
- node_label: Label of the starting node
276
- max_depth: Maximum depth of the subgraph
277
- min_degree: Minimum degree of nodes to include. Defaults to 0
278
- inclusive: Do an inclusive search if true
279
 
280
  Returns:
281
  KnowledgeGraph object containing nodes and edges
282
  """
283
- result = KnowledgeGraph()
284
- seen_nodes = set()
285
- seen_edges = set()
286
-
287
  graph = await self._get_graph()
288
 
289
- # Initialize sets for start nodes and direct connected nodes
290
- start_nodes = set()
291
- direct_connected_nodes = set()
292
-
293
  # Handle special case for "*" label
294
  if node_label == "*":
295
- # For "*", return the entire graph including all nodes and edges
296
- subgraph = (
297
- graph.copy()
298
- ) # Create a copy to avoid modifying the original graph
 
 
 
299
  else:
300
- # Find nodes with matching node id based on search_mode
301
- nodes_to_explore = []
302
- for n, attr in graph.nodes(data=True):
303
- node_str = str(n)
304
- if not inclusive:
305
- if node_label == node_str: # Use exact matching
306
- nodes_to_explore.append(n)
307
- else: # inclusive mode
308
- if node_label in node_str: # Use partial matching
309
- nodes_to_explore.append(n)
310
-
311
- if not nodes_to_explore:
312
- logger.warning(f"No nodes found with label {node_label}")
313
- return result
314
-
315
- # Get subgraph using ego_graph from all matching nodes
316
- combined_subgraph = nx.Graph()
317
- for start_node in nodes_to_explore:
318
- node_subgraph = nx.ego_graph(graph, start_node, radius=max_depth)
319
- combined_subgraph = nx.compose(combined_subgraph, node_subgraph)
320
-
321
- # Get start nodes and direct connected nodes
322
- if nodes_to_explore:
323
- start_nodes = set(nodes_to_explore)
324
- # Get nodes directly connected to all start nodes
325
- for start_node in start_nodes:
326
- direct_connected_nodes.update(
327
- combined_subgraph.neighbors(start_node)
328
- )
329
-
330
- # Remove start nodes from directly connected nodes (avoid duplicates)
331
- direct_connected_nodes -= start_nodes
332
-
333
- subgraph = combined_subgraph
334
-
335
- # Filter nodes based on min_degree, but keep start nodes and direct connected nodes
336
- if min_degree > 0:
337
- nodes_to_keep = [
338
- node
339
- for node, degree in subgraph.degree()
340
- if node in start_nodes
341
- or node in direct_connected_nodes
342
- or degree >= min_degree
343
- ]
344
- subgraph = subgraph.subgraph(nodes_to_keep)
345
-
346
- # Check if number of nodes exceeds max_graph_nodes
347
- if len(subgraph.nodes()) > MAX_GRAPH_NODES:
348
- origin_nodes = len(subgraph.nodes())
349
- node_degrees = dict(subgraph.degree())
350
-
351
- def priority_key(node_item):
352
- node, degree = node_item
353
- # Priority order: start(2) > directly connected(1) > other nodes(0)
354
- if node in start_nodes:
355
- priority = 2
356
- elif node in direct_connected_nodes:
357
- priority = 1
358
- else:
359
- priority = 0
360
- return (priority, degree)
361
-
362
- # Sort by priority and degree and select top MAX_GRAPH_NODES nodes
363
- top_nodes = sorted(node_degrees.items(), key=priority_key, reverse=True)[
364
- :MAX_GRAPH_NODES
365
- ]
366
- top_node_ids = [node[0] for node in top_nodes]
367
- # Create new subgraph and keep nodes only with most degree
368
- subgraph = subgraph.subgraph(top_node_ids)
369
- logger.info(
370
- f"Reduced graph from {origin_nodes} nodes to {MAX_GRAPH_NODES} nodes (depth={max_depth})"
371
- )
372
 
373
  # Add nodes to result
 
 
 
374
  for node in subgraph.nodes():
375
  if str(node) in seen_nodes:
376
  continue
@@ -398,7 +339,7 @@ class NetworkXStorage(BaseGraphStorage):
398
  for edge in subgraph.edges():
399
  source, target = edge
400
  # Esure unique edge_id for undirect graph
401
- if source > target:
402
  source, target = target, source
403
  edge_id = f"{source}-{target}"
404
  if edge_id in seen_edges:
 
259
  self,
260
  node_label: str,
261
  max_depth: int = 3,
262
+ max_nodes: int = MAX_GRAPH_NODES,
 
263
  ) -> KnowledgeGraph:
264
  """
265
  Retrieve a connected subgraph of nodes where the label includes the specified `node_label`.
 
 
 
 
 
 
266
 
267
  Args:
268
+ node_label: Label of the starting node,* means all nodes
269
+ max_depth: Maximum depth of the subgraph, Defaults to 3
270
+ max_nodes: Maxiumu nodes to return by BFS, Defaults to 1000
 
271
 
272
  Returns:
273
  KnowledgeGraph object containing nodes and edges
274
  """
 
 
 
 
275
  graph = await self._get_graph()
276
 
 
 
 
 
277
  # Handle special case for "*" label
278
  if node_label == "*":
279
+ # Get degrees of all nodes
280
+ degrees = dict(graph.degree())
281
+ # Sort nodes by degree in descending order and take top max_nodes
282
+ sorted_nodes = sorted(degrees.items(), key=lambda x: x[1], reverse=True)
283
+ limited_nodes = [node for node, _ in sorted_nodes[:max_nodes]]
284
+ # Create subgraph with the highest degree nodes
285
+ subgraph = graph.subgraph(limited_nodes)
286
  else:
287
+ # Check if node exists
288
+ if node_label not in graph:
289
+ logger.warning(f"Node {node_label} not found in the graph")
290
+ return KnowledgeGraph() # Return empty graph
291
+
292
+ # Use BFS to get nodes
293
+ bfs_nodes = []
294
+ visited = set()
295
+ queue = [node_label]
296
+
297
+ # Breadth-first search
298
+ while queue and len(bfs_nodes) < max_nodes:
299
+ current = queue.pop(0)
300
+ if current not in visited:
301
+ visited.add(current)
302
+ bfs_nodes.append(current)
303
+
304
+ # Add neighbor nodes to queue
305
+ neighbors = list(graph.neighbors(current))
306
+ queue.extend([n for n in neighbors if n not in visited])
307
+
308
+ # Create subgraph with BFS discovered nodes
309
+ subgraph = graph.subgraph(bfs_nodes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
310
 
311
  # Add nodes to result
312
+ result = KnowledgeGraph()
313
+ seen_nodes = set()
314
+ seen_edges = set()
315
  for node in subgraph.nodes():
316
  if str(node) in seen_nodes:
317
  continue
 
339
  for edge in subgraph.edges():
340
  source, target = edge
341
  # Esure unique edge_id for undirect graph
342
+ if str(source) > str(target):
343
  source, target = target, source
344
  edge_id = f"{source}-{target}"
345
  if edge_id in seen_edges: