gzdaniel committed on
Commit e6b8d67 · 1 Parent(s): ddc7121

fix: optimize MongoDB aggregation pipeline to prevent memory limit errors


- Move $limit operation early in pipeline for "*" queries to reduce memory usage
- Remove memory-intensive $sort operation for large dataset queries
- Add fallback mechanism for memory limit errors with simple query
- Implement additional safety checks to enforce max_nodes limit
- Improve error handling and logging for memory-related issues
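
The reordering described in the first two bullets is easiest to see side by side: the old pipeline ran $graphLookup over every node and then sorted everything by edge count before truncating, while the new "*" pipeline truncates first and drops the sort (the stage the commit message flags as memory-intensive) entirely. A minimal sketch, with stage contents copied from the diff below; the helper name is illustrative and not part of the codebase:

def build_star_pipelines(edge_collection: str, max_depth: int, max_nodes: int):
    """Contrast the old and new pipelines used when label == "*" (sketch only)."""
    graph_lookup = {
        "$graphLookup": {
            "from": edge_collection,
            "startWith": "$_id",
            "connectFromField": "target_node_id",
            "connectToField": "source_node_id",
            "maxDepth": max_depth,
            "depthField": "depth",
            "as": "connected_edges",
        }
    }
    # Old shape: expand edges for every node, then sort the whole collection
    # by edge_count before applying the limit.
    old_pipeline = [
        graph_lookup,
        {"$addFields": {"edge_count": {"$size": "$connected_edges"}}},
        {"$sort": {"edge_count": -1}},
        {"$limit": max_nodes},
    ]
    # New shape: limit first so at most max_nodes documents reach $graphLookup,
    # and skip the sort altogether.
    new_pipeline = [
        {"$limit": max_nodes},
        graph_lookup,
    ]
    return old_pipeline, new_pipeline

One consequence visible in the diff: without the sort, the pipeline keeps the first max_nodes documents rather than the most-connected ones, a trade-off the commit accepts in order to stay within the memory limit.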

Files changed (1)
  1. lightrag/kg/mongo_impl.py +70 -20
lightrag/kg/mongo_impl.py CHANGED
@@ -732,24 +732,25 @@ class MongoGraphStorage(BaseGraphStorage):
         node_edges = []

         try:
-            pipeline = [
-                {
-                    "$graphLookup": {
-                        "from": self._edge_collection_name,
-                        "startWith": "$_id",
-                        "connectFromField": "target_node_id",
-                        "connectToField": "source_node_id",
-                        "maxDepth": max_depth,
-                        "depthField": "depth",
-                        "as": "connected_edges",
+            # Optimize pipeline to avoid memory issues with large datasets
+            if label == "*":
+                # For getting all nodes, use a simpler pipeline to avoid memory issues
+                pipeline = [
+                    {"$limit": max_nodes},  # Limit early to reduce memory usage
+                    {
+                        "$graphLookup": {
+                            "from": self._edge_collection_name,
+                            "startWith": "$_id",
+                            "connectFromField": "target_node_id",
+                            "connectToField": "source_node_id",
+                            "maxDepth": max_depth,
+                            "depthField": "depth",
+                            "as": "connected_edges",
+                        },
                     },
-                },
-                {"$addFields": {"edge_count": {"$size": "$connected_edges"}}},
-                {"$sort": {"edge_count": -1}},
-                {"$limit": max_nodes},
-            ]
+                ]

-            if label == "*":
+                # Check if we need to set truncation flag
                 all_node_count = await self.collection.count_documents({})
                 result.is_truncated = all_node_count > max_nodes
             else:
@@ -759,10 +760,28 @@ class MongoGraphStorage(BaseGraphStorage):
                     logger.warning(f"Starting node with label {label} does not exist!")
                     return result

-                # Add starting node to pipeline
-                pipeline.insert(0, {"$match": {"_id": label}})
+                # For specific node queries, use the original pipeline but optimized
+                pipeline = [
+                    {"$match": {"_id": label}},
+                    {
+                        "$graphLookup": {
+                            "from": self._edge_collection_name,
+                            "startWith": "$_id",
+                            "connectFromField": "target_node_id",
+                            "connectToField": "source_node_id",
+                            "maxDepth": max_depth,
+                            "depthField": "depth",
+                            "as": "connected_edges",
+                        },
+                    },
+                    {"$addFields": {"edge_count": {"$size": "$connected_edges"}}},
+                    {"$sort": {"edge_count": -1}},
+                    {"$limit": max_nodes},
+                ]

             cursor = await self.collection.aggregate(pipeline, allowDiskUse=True)
+            nodes_processed = 0
+
             async for doc in cursor:
                 # Add the start node
                 node_id = str(doc["_id"])
@@ -786,6 +805,13 @@
                 if doc.get("connected_edges", []):
                     node_edges.extend(doc.get("connected_edges"))

+                nodes_processed += 1
+
+                # Additional safety check to prevent memory issues
+                if nodes_processed >= max_nodes:
+                    result.is_truncated = True
+                    break
+
             for edge in node_edges:
                 if (
                     edge["source_node_id"] not in seen_nodes
@@ -817,11 +843,35 @@
                     seen_edges.add(edge_id)

             logger.info(
-                f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
+                f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)} | Truncated: {result.is_truncated}"
             )

         except PyMongoError as e:
-            logger.error(f"MongoDB query failed: {str(e)}")
+            # Handle memory limit errors specifically
+            if "memory limit" in str(e).lower() or "sort exceeded" in str(e).lower():
+                logger.warning(
+                    f"MongoDB memory limit exceeded, falling back to simple query: {str(e)}"
+                )
+                # Fallback to a simple query without complex aggregation
+                try:
+                    simple_cursor = self.collection.find({}).limit(max_nodes)
+                    async for doc in simple_cursor:
+                        node_id = str(doc["_id"])
+                        result.nodes.append(
+                            KnowledgeGraphNode(
+                                id=node_id,
+                                labels=[node_id],
+                                properties={k: v for k, v in doc.items() if k != "_id"},
+                            )
+                        )
+                    result.is_truncated = True
+                    logger.info(
+                        f"Fallback query completed | Node count: {len(result.nodes)}"
+                    )
+                except PyMongoError as fallback_error:
+                    logger.error(f"Fallback query also failed: {str(fallback_error)}")
+            else:
+                logger.error(f"MongoDB query failed: {str(e)}")

         return result
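
For the fallback added in the except branch (third bullet of the commit message), the pattern is: detect a memory-related error by its message, rebuild the result from a plain find().limit() that returns nodes only, with no edges, and mark the result truncated. A rough standalone sketch of that pattern, assuming a Motor-style async collection and the same KnowledgeGraphNode class used in the diff; the function name and the lightrag.types import path are assumptions:

import logging

from lightrag.types import KnowledgeGraphNode  # import path assumed
from pymongo.errors import PyMongoError

logger = logging.getLogger(__name__)


async def load_nodes_without_edges(collection, result, max_nodes: int):
    """Sketch of the fallback: fetch up to max_nodes bare node documents."""
    try:
        # No $graphLookup, $sort, or $addFields here, so memory pressure stays
        # minimal; the subgraph comes back without edges and is marked truncated.
        async for doc in collection.find({}).limit(max_nodes):
            node_id = str(doc["_id"])
            result.nodes.append(
                KnowledgeGraphNode(
                    id=node_id,
                    labels=[node_id],
                    properties={k: v for k, v in doc.items() if k != "_id"},
                )
            )
        result.is_truncated = True
    except PyMongoError as fallback_error:
        # Mirror the diff: log and return whatever was collected so far.
        logger.error(f"Fallback query also failed: {fallback_error}")
    return result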