Ken Chen committed on
Commit
465a442
·
1 Parent(s): e6707c3

Implement get_nodes_by_chunk_ids and get_edges_by_chunk_ids.

Browse files
lightrag/kg/mongo_impl.py CHANGED
@@ -17,6 +17,8 @@ from ..base import (
17
  from ..namespace import NameSpace, is_namespace
18
  from ..utils import logger, compute_mdhash_id
19
  from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
 
 
20
  import pipmaster as pm
21
 
22
  if not pm.is_installed("pymongo"):
@@ -353,33 +355,33 @@ class MongoGraphStorage(BaseGraphStorage):
353
  self.collection = None
354
  self.edge_collection = None
355
 
356
- #
357
- # -------------------------------------------------------------------------
358
- # HELPER: $graphLookup pipeline
359
- # -------------------------------------------------------------------------
360
- #
361
 
362
- # Sample entity_relation document
363
  # {
364
  # "_id" : "CompanyA",
365
- # "created_at" : 1749904575,
366
- # "description" : "A major technology company",
367
- # "edges" : [
368
- # {
369
- # "target" : "ProductX",
370
- # "relation": "Develops", // To distinguish multiple same-target relations
371
- # "weight" : Double("1"),
372
- # "description" : "CompanyA develops ProductX",
373
- # "keywords" : "develop, produce",
374
- # "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec",
375
- # "file_path" : "custom_kg",
376
- # "created_at" : 1749904575
377
- # }
378
- # ],
379
  # "entity_id" : "CompanyA",
380
  # "entity_type" : "Organization",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
  # "file_path" : "custom_kg",
382
- # "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec"
383
  # }
384
 
385
  #
@@ -567,6 +569,45 @@ class MongoGraphStorage(BaseGraphStorage):
567
 
568
  return result
569
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
570
  #
571
  # -------------------------------------------------------------------------
572
  # UPSERTS
@@ -578,6 +619,11 @@ class MongoGraphStorage(BaseGraphStorage):
578
  Insert or update a node document.
579
  """
580
  update_doc = {"$set": {**node_data}}
 
 
 
 
 
581
  await self.collection.update_one({"_id": node_id}, update_doc, upsert=True)
582
 
583
  async def upsert_edge(
@@ -590,9 +636,15 @@ class MongoGraphStorage(BaseGraphStorage):
590
  # Ensure source node exists
591
  await self.upsert_node(source_node_id, {})
592
 
 
 
 
 
 
 
593
  await self.edge_collection.update_one(
594
  {"source_node_id": source_node_id, "target_node_id": target_node_id},
595
- {"$set": edge_data},
596
  upsert=True,
597
  )
598
 
@@ -789,14 +841,16 @@ class MongoGraphStorage(BaseGraphStorage):
789
  if not edges:
790
  return
791
 
792
- await self.edge_collection.delete_many(
793
- {
794
- "$or": [
795
- {"source_node_id": source_id, "target_node_id": target_id}
796
- for source_id, target_id in edges
797
- ]
798
- }
799
- )
 
 
800
 
801
  logger.debug(f"Successfully deleted edges: {edges}")
802
 
 
17
  from ..namespace import NameSpace, is_namespace
18
  from ..utils import logger, compute_mdhash_id
19
  from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
20
+ from ..constants import GRAPH_FIELD_SEP
21
+
22
  import pipmaster as pm
23
 
24
  if not pm.is_installed("pymongo"):
 
355
  self.collection = None
356
  self.edge_collection = None
357
 
358
+ # Sample entity document
359
+ # "source_ids" is Array representation of "source_id" split by GRAPH_FIELD_SEP
 
 
 
360
 
 
361
  # {
362
  # "_id" : "CompanyA",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
363
  # "entity_id" : "CompanyA",
364
  # "entity_type" : "Organization",
365
+ # "description" : "A major technology company",
366
+ # "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec",
367
+ # "source_ids": ["chunk-eeec0036b909839e8ec4fa150c939eec"],
368
+ # "file_path" : "custom_kg",
369
+ # "created_at" : 1749904575
370
+ # }
371
+
372
+ # Sample relation document
373
+ # {
374
+ # "_id" : ObjectId("6856ac6e7c6bad9b5470b678"), // MongoDB build-in ObjectId
375
+ # "description" : "CompanyA develops ProductX",
376
+ # "source_node_id" : "CompanyA",
377
+ # "target_node_id" : "ProductX",
378
+ # "relationship": "Develops", // To distinguish multiple same-target relations
379
+ # "weight" : Double("1"),
380
+ # "keywords" : "develop, produce",
381
+ # "source_id" : "chunk-eeec0036b909839e8ec4fa150c939eec",
382
+ # "source_ids": ["chunk-eeec0036b909839e8ec4fa150c939eec"],
383
  # "file_path" : "custom_kg",
384
+ # "created_at" : 1749904575
385
  # }
386
 
387
  #
 
569
 
570
  return result
571
 
572
+ async def get_nodes_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
573
+ """Get all nodes that are associated with the given chunk_ids.
574
+
575
+ Args:
576
+ chunk_ids (list[str]): A list of chunk IDs to find associated nodes for.
577
+
578
+ Returns:
579
+ list[dict]: A list of nodes, where each node is a dictionary of its properties.
580
+ An empty list if no matching nodes are found.
581
+ """
582
+ if not chunk_ids:
583
+ return []
584
+
585
+ cursor = self.collection.find({"source_ids": {"$in": chunk_ids}})
586
+ return [doc async for doc in cursor]
587
+
588
+ async def get_edges_by_chunk_ids(self, chunk_ids: list[str]) -> list[dict]:
589
+ """Get all edges that are associated with the given chunk_ids.
590
+
591
+ Args:
592
+ chunk_ids (list[str]): A list of chunk IDs to find associated edges for.
593
+
594
+ Returns:
595
+ list[dict]: A list of edges, where each edge is a dictionary of its properties.
596
+ An empty list if no matching edges are found.
597
+ """
598
+ if not chunk_ids:
599
+ return []
600
+
601
+ cursor = self.edge_collection.find({"source_ids": {"$in": chunk_ids}})
602
+
603
+ edges = []
604
+ async for edge in cursor:
605
+ edge["source"] = edge["source_node_id"]
606
+ edge["target"] = edge["target_node_id"]
607
+ edges.append(edge)
608
+
609
+ return edges
610
+
611
  #
612
  # -------------------------------------------------------------------------
613
  # UPSERTS
 
619
  Insert or update a node document.
620
  """
621
  update_doc = {"$set": {**node_data}}
622
+ if node_data.get("source_id", ""):
623
+ update_doc["$set"]["source_ids"] = node_data["source_id"].split(
624
+ GRAPH_FIELD_SEP
625
+ )
626
+
627
  await self.collection.update_one({"_id": node_id}, update_doc, upsert=True)
628
 
629
  async def upsert_edge(
 
636
  # Ensure source node exists
637
  await self.upsert_node(source_node_id, {})
638
 
639
+ update_doc = {"$set": edge_data}
640
+ if edge_data.get("source_id", ""):
641
+ update_doc["$set"]["source_ids"] = edge_data["source_id"].split(
642
+ GRAPH_FIELD_SEP
643
+ )
644
+
645
  await self.edge_collection.update_one(
646
  {"source_node_id": source_node_id, "target_node_id": target_node_id},
647
+ update_doc,
648
  upsert=True,
649
  )
650
 
 
841
  if not edges:
842
  return
843
 
844
+ all_edge_pairs = []
845
+ for source_id, target_id in edges:
846
+ all_edge_pairs.append(
847
+ {"source_node_id": source_id, "target_node_id": target_id}
848
+ )
849
+ all_edge_pairs.append(
850
+ {"source_node_id": target_id, "target_node_id": source_id}
851
+ )
852
+
853
+ await self.edge_collection.delete_many({"$or": all_edge_pairs})
854
 
855
  logger.debug(f"Successfully deleted edges: {edges}")
856
 
tests/test_graph_storage.py CHANGED
@@ -30,6 +30,7 @@ from lightrag.kg import (
30
  verify_storage_implementation,
31
  )
32
  from lightrag.kg.shared_storage import initialize_share_data
 
33
 
34
 
35
  # 模拟的嵌入函数,返回随机向量
@@ -437,6 +438,9 @@ async def test_graph_batch_operations(storage):
437
  5. 使用 get_nodes_edges_batch 批量获取多个节点的所有边
438
  """
439
  try:
 
 
 
440
  # 1. 插入测试数据
441
  # 插入节点1: 人工智能
442
  node1_id = "人工智能"
@@ -445,6 +449,7 @@ async def test_graph_batch_operations(storage):
445
  "description": "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。",
446
  "keywords": "AI,机器学习,深度学习",
447
  "entity_type": "技术领域",
 
448
  }
449
  print(f"插入节点1: {node1_id}")
450
  await storage.upsert_node(node1_id, node1_data)
@@ -456,6 +461,7 @@ async def test_graph_batch_operations(storage):
456
  "description": "机器学习是人工智能的一个分支,它使用统计学方法让计算机系统在不被明确编程的情况下也能够学习。",
457
  "keywords": "监督学习,无监督学习,强化学习",
458
  "entity_type": "技术领域",
 
459
  }
460
  print(f"插入节点2: {node2_id}")
461
  await storage.upsert_node(node2_id, node2_data)
@@ -467,6 +473,7 @@ async def test_graph_batch_operations(storage):
467
  "description": "深度学习是机器学习的一个分支,它使用多层神经网络来模拟人脑的学习过程。",
468
  "keywords": "神经网络,CNN,RNN",
469
  "entity_type": "技术领域",
 
470
  }
471
  print(f"插入节点3: {node3_id}")
472
  await storage.upsert_node(node3_id, node3_data)
@@ -498,6 +505,7 @@ async def test_graph_batch_operations(storage):
498
  "relationship": "包含",
499
  "weight": 1.0,
500
  "description": "人工智能领域包含机器学习这个子领域",
 
501
  }
502
  print(f"插入边1: {node1_id} -> {node2_id}")
503
  await storage.upsert_edge(node1_id, node2_id, edge1_data)
@@ -507,6 +515,7 @@ async def test_graph_batch_operations(storage):
507
  "relationship": "包含",
508
  "weight": 1.0,
509
  "description": "机器学习领域包含深度学习这个子领域",
 
510
  }
511
  print(f"插入边2: {node2_id} -> {node3_id}")
512
  await storage.upsert_edge(node2_id, node3_id, edge2_data)
@@ -516,6 +525,7 @@ async def test_graph_batch_operations(storage):
516
  "relationship": "包含",
517
  "weight": 1.0,
518
  "description": "人工智能领域包含自然语言处理这个子领域",
 
519
  }
520
  print(f"插入边3: {node1_id} -> {node4_id}")
521
  await storage.upsert_edge(node1_id, node4_id, edge3_data)
@@ -748,6 +758,76 @@ async def test_graph_batch_operations(storage):
748
 
749
  print("无向图特性验证成功:批量获取的节点边包含所有相关的边(无论方向)")
750
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
751
  print("\n批量操作测试完成")
752
  return True
753
 
 
30
  verify_storage_implementation,
31
  )
32
  from lightrag.kg.shared_storage import initialize_share_data
33
+ from lightrag.constants import GRAPH_FIELD_SEP
34
 
35
 
36
  # 模拟的嵌入函数,返回随机向量
 
438
  5. 使用 get_nodes_edges_batch 批量获取多个节点的所有边
439
  """
440
  try:
441
+ chunk1_id = "1"
442
+ chunk2_id = "2"
443
+ chunk3_id = "3"
444
  # 1. 插入测试数据
445
  # 插入节点1: 人工智能
446
  node1_id = "人工智能"
 
449
  "description": "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。",
450
  "keywords": "AI,机器学习,深度学习",
451
  "entity_type": "技术领域",
452
+ "source_id": GRAPH_FIELD_SEP.join([chunk1_id, chunk2_id]),
453
  }
454
  print(f"插入节点1: {node1_id}")
455
  await storage.upsert_node(node1_id, node1_data)
 
461
  "description": "机器学习是人工智能的一个分支,它使用统计学方法让计算机系统在不被明确编程的情况下也能够学习。",
462
  "keywords": "监督学习,无监督学习,强化学习",
463
  "entity_type": "技术领域",
464
+ "source_id": GRAPH_FIELD_SEP.join([chunk2_id, chunk3_id]),
465
  }
466
  print(f"插入节点2: {node2_id}")
467
  await storage.upsert_node(node2_id, node2_data)
 
473
  "description": "深度学习是机器学习的一个分支,它使用多层神经网络来模拟人脑的学习过程。",
474
  "keywords": "神经网络,CNN,RNN",
475
  "entity_type": "技术领域",
476
+ "source_id": GRAPH_FIELD_SEP.join([chunk3_id]),
477
  }
478
  print(f"插入节点3: {node3_id}")
479
  await storage.upsert_node(node3_id, node3_data)
 
505
  "relationship": "包含",
506
  "weight": 1.0,
507
  "description": "人工智能领域包含机器学习这个子领域",
508
+ "source_id": GRAPH_FIELD_SEP.join([chunk1_id, chunk2_id]),
509
  }
510
  print(f"插入边1: {node1_id} -> {node2_id}")
511
  await storage.upsert_edge(node1_id, node2_id, edge1_data)
 
515
  "relationship": "包含",
516
  "weight": 1.0,
517
  "description": "机器学习领域包含深度学习这个子领域",
518
+ "source_id": GRAPH_FIELD_SEP.join([chunk2_id, chunk3_id]),
519
  }
520
  print(f"插入边2: {node2_id} -> {node3_id}")
521
  await storage.upsert_edge(node2_id, node3_id, edge2_data)
 
525
  "relationship": "包含",
526
  "weight": 1.0,
527
  "description": "人工智能领域包含自然语言处理这个子领域",
528
+ "source_id": GRAPH_FIELD_SEP.join([chunk3_id]),
529
  }
530
  print(f"插入边3: {node1_id} -> {node4_id}")
531
  await storage.upsert_edge(node1_id, node4_id, edge3_data)
 
758
 
759
  print("无向图特性验证成功:批量获取的节点边包含所有相关的边(无论方向)")
760
 
761
+ # 7. 测试 get_nodes_by_chunk_ids - 批量根据 chunk_ids 获取多个节点
762
+ print("== 测试 get_nodes_by_chunk_ids")
763
+
764
+ print("== 测试单个 chunk_id,匹配多个节点")
765
+ nodes = await storage.get_nodes_by_chunk_ids([chunk2_id])
766
+ assert len(nodes) == 2, f"{chunk1_id} 应有2个节点,实际有 {len(nodes)} 个"
767
+
768
+ has_node1 = any(node["entity_id"] == node1_id for node in nodes)
769
+ has_node2 = any(node["entity_id"] == node2_id for node in nodes)
770
+
771
+ assert has_node1, f"节点 {node1_id} 应在返回结果中"
772
+ assert has_node2, f"节点 {node2_id} 应在返回结果中"
773
+
774
+ print("== 测试多个 chunk_id,部分匹配多个节点")
775
+ nodes = await storage.get_nodes_by_chunk_ids([chunk2_id, chunk3_id])
776
+ assert (
777
+ len(nodes) == 3
778
+ ), f"{chunk2_id}, {chunk3_id} 应有3个节点,实际有 {len(nodes)} 个"
779
+
780
+ has_node1 = any(node["entity_id"] == node1_id for node in nodes)
781
+ has_node2 = any(node["entity_id"] == node2_id for node in nodes)
782
+ has_node3 = any(node["entity_id"] == node3_id for node in nodes)
783
+
784
+ assert has_node1, f"节点 {node1_id} 应在返回结果中"
785
+ assert has_node2, f"节点 {node2_id} 应在返回结果中"
786
+ assert has_node3, f"节点 {node3_id} 应在返回结果中"
787
+
788
+ # 8. 测试 get_edges_by_chunk_ids - 批量根据 chunk_ids 获取多条边
789
+ print("== 测试 get_edges_by_chunk_ids")
790
+
791
+ print("== 测试单个 chunk_id,匹配多条边")
792
+ edges = await storage.get_edges_by_chunk_ids([chunk2_id])
793
+ assert len(edges) == 2, f"{chunk2_id} 应有2条边,实际有 {len(edges)} 条"
794
+
795
+ has_edge_node1_node2 = any(
796
+ edge["source"] == node1_id and edge["target"] == node2_id for edge in edges
797
+ )
798
+ has_edge_node2_node3 = any(
799
+ edge["source"] == node2_id and edge["target"] == node3_id for edge in edges
800
+ )
801
+
802
+ assert has_edge_node1_node2, f"{chunk2_id} 应包含 {node1_id} 到 {node2_id} 的边"
803
+ assert has_edge_node2_node3, f"{chunk2_id} 应包含 {node2_id} 到 {node3_id} 的边"
804
+
805
+ print("== 测试多个 chunk_id,部分匹配多条边")
806
+ edges = await storage.get_edges_by_chunk_ids([chunk2_id, chunk3_id])
807
+ assert (
808
+ len(edges) == 3
809
+ ), f"{chunk2_id}, {chunk3_id} 应有3条边,实际有 {len(edges)} 条"
810
+
811
+ has_edge_node1_node2 = any(
812
+ edge["source"] == node1_id and edge["target"] == node2_id for edge in edges
813
+ )
814
+ has_edge_node2_node3 = any(
815
+ edge["source"] == node2_id and edge["target"] == node3_id for edge in edges
816
+ )
817
+ has_edge_node1_node4 = any(
818
+ edge["source"] == node1_id and edge["target"] == node4_id for edge in edges
819
+ )
820
+
821
+ assert (
822
+ has_edge_node1_node2
823
+ ), f"{chunk2_id}, {chunk3_id} 应包含 {node1_id} 到 {node2_id} 的边"
824
+ assert (
825
+ has_edge_node2_node3
826
+ ), f"{chunk2_id}, {chunk3_id} 应包含 {node2_id} 到 {node3_id} 的边"
827
+ assert (
828
+ has_edge_node1_node4
829
+ ), f"{chunk2_id}, {chunk3_id} 应包含 {node1_id} 到 {node4_id} 的边"
830
+
831
  print("\n批量操作测试完成")
832
  return True
833