yangdx commited on
Commit
fe876b6
·
1 Parent(s): ec8fba9

Add batch query unit test for grap storage

Browse files
Files changed (1) hide show
  1. tests/test_graph_storage.py +211 -3
tests/test_graph_storage.py CHANGED
@@ -377,6 +377,207 @@ async def test_graph_advanced(storage):
377
  return False
378
 
379
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
380
  async def main():
381
  """主函数"""
382
  # 显示程序标题
@@ -411,21 +612,28 @@ async def main():
411
  ASCIIColors.yellow("\n请选择测试类型:")
412
  ASCIIColors.white("1. 基本测试 (节点和边的插入、读取)")
413
  ASCIIColors.white("2. 高级测试 (度数、标签、知识图谱、删除操作等)")
414
- ASCIIColors.white("3. 全部测试")
 
415
 
416
- choice = input("\n请输入选项 (1/2/3): ")
417
 
418
  if choice == "1":
419
  await test_graph_basic(storage)
420
  elif choice == "2":
421
  await test_graph_advanced(storage)
422
  elif choice == "3":
 
 
423
  ASCIIColors.cyan("\n=== 开始基本测试 ===")
424
  basic_result = await test_graph_basic(storage)
425
 
426
  if basic_result:
427
  ASCIIColors.cyan("\n=== 开始高级测试 ===")
428
- await test_graph_advanced(storage)
 
 
 
 
429
  else:
430
  ASCIIColors.red("无效的选项")
431
 
 
377
  return False
378
 
379
 
380
+ async def test_graph_batch_operations(storage):
381
+ """
382
+ 测试图数据库的批量操作:
383
+ 1. 使用 get_nodes_batch 批量获取多个节点的属性
384
+ 2. 使用 node_degrees_batch 批量获取多个节点的度数
385
+ 3. 使用 edge_degrees_batch 批量获取多个边的度数
386
+ 4. 使用 get_edges_batch 批量获取多个边的属性
387
+ 5. 使用 get_nodes_edges_batch 批量获取多个节点的所有边
388
+ """
389
+ try:
390
+ # 清理之前的测试数据
391
+ print("清理之前的测试数据...\n")
392
+ await storage.drop()
393
+
394
+ # 1. 插入测试数据
395
+ # 插入节点1: 人工智能
396
+ node1_id = "人工智能"
397
+ node1_data = {
398
+ "entity_id": node1_id,
399
+ "description": "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。",
400
+ "keywords": "AI,机器学习,深度学习",
401
+ "entity_type": "技术领域",
402
+ }
403
+ print(f"插入节点1: {node1_id}")
404
+ await storage.upsert_node(node1_id, node1_data)
405
+
406
+ # 插入节点2: 机器学习
407
+ node2_id = "机器学习"
408
+ node2_data = {
409
+ "entity_id": node2_id,
410
+ "description": "机器学习是人工智能的一个分支,它使用统计学方法让计算机系统在不被明确编程的情况下也能够学习。",
411
+ "keywords": "监督学习,无监督学习,强化学习",
412
+ "entity_type": "技术领域",
413
+ }
414
+ print(f"插入节点2: {node2_id}")
415
+ await storage.upsert_node(node2_id, node2_data)
416
+
417
+ # 插入节点3: 深度学习
418
+ node3_id = "深度学习"
419
+ node3_data = {
420
+ "entity_id": node3_id,
421
+ "description": "深度学习是机器学习的一个分支,它使用多层神经网络来模拟人脑的学习过程。",
422
+ "keywords": "神经网络,CNN,RNN",
423
+ "entity_type": "技术领域",
424
+ }
425
+ print(f"插入节点3: {node3_id}")
426
+ await storage.upsert_node(node3_id, node3_data)
427
+
428
+ # 插入节点4: 自然语言处理
429
+ node4_id = "自然语言处理"
430
+ node4_data = {
431
+ "entity_id": node4_id,
432
+ "description": "自然语言处理是人工智能的一个分支,专注于使计算机理解和处理人类语言。",
433
+ "keywords": "NLP,文本分析,语言模型",
434
+ "entity_type": "技术领域",
435
+ }
436
+ print(f"插入节点4: {node4_id}")
437
+ await storage.upsert_node(node4_id, node4_data)
438
+
439
+ # 插入节点5: 计算机视觉
440
+ node5_id = "计算机视觉"
441
+ node5_data = {
442
+ "entity_id": node5_id,
443
+ "description": "计算机视觉是人工智能的一个分支,专注于使计算机能够从图像或视频中获取信息。",
444
+ "keywords": "CV,图像识别,目标检测",
445
+ "entity_type": "技术领域",
446
+ }
447
+ print(f"插入节点5: {node5_id}")
448
+ await storage.upsert_node(node5_id, node5_data)
449
+
450
+ # 插入边1: 人工智能 -> 机器学习
451
+ edge1_data = {
452
+ "relationship": "包含",
453
+ "weight": 1.0,
454
+ "description": "人工智能领域包含机器学习这个子领域",
455
+ }
456
+ print(f"插入边1: {node1_id} -> {node2_id}")
457
+ await storage.upsert_edge(node1_id, node2_id, edge1_data)
458
+
459
+ # 插入边2: 机器学习 -> 深度学习
460
+ edge2_data = {
461
+ "relationship": "包含",
462
+ "weight": 1.0,
463
+ "description": "机器学习领域包含深度学习这个子领域",
464
+ }
465
+ print(f"插入边2: {node2_id} -> {node3_id}")
466
+ await storage.upsert_edge(node2_id, node3_id, edge2_data)
467
+
468
+ # 插入边3: 人工智能 -> 自然语言处理
469
+ edge3_data = {
470
+ "relationship": "包含",
471
+ "weight": 1.0,
472
+ "description": "人工智能领域包含自然语言处理这个子领域",
473
+ }
474
+ print(f"插入边3: {node1_id} -> {node4_id}")
475
+ await storage.upsert_edge(node1_id, node4_id, edge3_data)
476
+
477
+ # 插入边4: 人工智能 -> 计算机视觉
478
+ edge4_data = {
479
+ "relationship": "包含",
480
+ "weight": 1.0,
481
+ "description": "人工智能领域包含计算机视觉这个子领域",
482
+ }
483
+ print(f"插入边4: {node1_id} -> {node5_id}")
484
+ await storage.upsert_edge(node1_id, node5_id, edge4_data)
485
+
486
+ # 插入边5: 深度学习 -> 自然语言处理
487
+ edge5_data = {
488
+ "relationship": "应用于",
489
+ "weight": 0.8,
490
+ "description": "深度学习技术应用于自然语言处理领域",
491
+ }
492
+ print(f"插入边5: {node3_id} -> {node4_id}")
493
+ await storage.upsert_edge(node3_id, node4_id, edge5_data)
494
+
495
+ # 插入边6: 深度��习 -> 计算机视觉
496
+ edge6_data = {
497
+ "relationship": "应用于",
498
+ "weight": 0.8,
499
+ "description": "深度学习技术应用于计算机视觉领域",
500
+ }
501
+ print(f"插入边6: {node3_id} -> {node5_id}")
502
+ await storage.upsert_edge(node3_id, node5_id, edge6_data)
503
+
504
+ # 2. 测试 get_nodes_batch - 批量获取多个节点的属性
505
+ print("== 测试 get_nodes_batch")
506
+ node_ids = [node1_id, node2_id, node3_id]
507
+ nodes_dict = await storage.get_nodes_batch(node_ids)
508
+ print(f"批量获取节点属性结果: {nodes_dict.keys()}")
509
+ assert len(nodes_dict) == 3, f"应返回3个节点,实际返回 {len(nodes_dict)} 个"
510
+ assert node1_id in nodes_dict, f"{node1_id} 应在返回结果中"
511
+ assert node2_id in nodes_dict, f"{node2_id} 应在返回结果中"
512
+ assert node3_id in nodes_dict, f"{node3_id} 应在返回结果中"
513
+ assert nodes_dict[node1_id]["description"] == node1_data["description"], f"{node1_id} 描述不匹配"
514
+ assert nodes_dict[node2_id]["description"] == node2_data["description"], f"{node2_id} 描述不匹配"
515
+ assert nodes_dict[node3_id]["description"] == node3_data["description"], f"{node3_id} 描述不匹配"
516
+
517
+ # 3. 测试 node_degrees_batch - 批量获取多个节点的度数
518
+ print("== 测试 node_degrees_batch")
519
+ node_degrees = await storage.node_degrees_batch(node_ids)
520
+ print(f"批量获取节点度数结果: {node_degrees}")
521
+ assert len(node_degrees) == 3, f"应返回3个节点的度数,实际返回 {len(node_degrees)} 个"
522
+ assert node1_id in node_degrees, f"{node1_id} 应在返回结果中"
523
+ assert node2_id in node_degrees, f"{node2_id} 应在返回结果中"
524
+ assert node3_id in node_degrees, f"{node3_id} 应在返回结果中"
525
+ assert node_degrees[node1_id] == 3, f"{node1_id} 度数应为3,实际为 {node_degrees[node1_id]}"
526
+ assert node_degrees[node2_id] == 2, f"{node2_id} 度数应为2,实际为 {node_degrees[node2_id]}"
527
+ assert node_degrees[node3_id] == 3, f"{node3_id} 度数应为3,实际为 {node_degrees[node3_id]}"
528
+
529
+ # 4. 测试 edge_degrees_batch - 批量获取多个边的度数
530
+ print("== 测试 edge_degrees_batch")
531
+ edges = [(node1_id, node2_id), (node2_id, node3_id), (node3_id, node4_id)]
532
+ edge_degrees = await storage.edge_degrees_batch(edges)
533
+ print(f"批量获取边度数结果: {edge_degrees}")
534
+ assert len(edge_degrees) == 3, f"应返回3条边的度数,实际返回 {len(edge_degrees)} 条"
535
+ assert (node1_id, node2_id) in edge_degrees, f"边 {node1_id} -> {node2_id} 应在返回结果中"
536
+ assert (node2_id, node3_id) in edge_degrees, f"边 {node2_id} -> {node3_id} 应在返回结果中"
537
+ assert (node3_id, node4_id) in edge_degrees, f"边 {node3_id} -> {node4_id} 应在返回结果中"
538
+ # 验证边的度数是否正确(源节点度数 + 目标节点度数)
539
+ assert edge_degrees[(node1_id, node2_id)] == 5, f"边 {node1_id} -> {node2_id} 度数应为5,实际为 {edge_degrees[(node1_id, node2_id)]}"
540
+ assert edge_degrees[(node2_id, node3_id)] == 5, f"边 {node2_id} -> {node3_id} 度数应为5,实际为 {edge_degrees[(node2_id, node3_id)]}"
541
+ assert edge_degrees[(node3_id, node4_id)] == 5, f"边 {node3_id} -> {node4_id} 度数应为5,实际为 {edge_degrees[(node3_id, node4_id)]}"
542
+
543
+ # 5. 测试 get_edges_batch - 批量获取多个边的属性
544
+ print("== 测试 get_edges_batch")
545
+ # 将元组列表转换为Neo4j风格的字典列表
546
+ edge_dicts = [{"src": src, "tgt": tgt} for src, tgt in edges]
547
+ edges_dict = await storage.get_edges_batch(edge_dicts)
548
+ print(f"批量获取边属性结果: {edges_dict.keys()}")
549
+ assert len(edges_dict) == 3, f"应返回3条边的属性,实际返回 {len(edges_dict)} 条"
550
+ assert (node1_id, node2_id) in edges_dict, f"边 {node1_id} -> {node2_id} 应在返回结果中"
551
+ assert (node2_id, node3_id) in edges_dict, f"边 {node2_id} -> {node3_id} 应在返回结果中"
552
+ assert (node3_id, node4_id) in edges_dict, f"边 {node3_id} -> {node4_id} 应在返回结果中"
553
+ assert edges_dict[(node1_id, node2_id)]["relationship"] == edge1_data["relationship"], f"边 {node1_id} -> {node2_id} 关系不匹配"
554
+ assert edges_dict[(node2_id, node3_id)]["relationship"] == edge2_data["relationship"], f"边 {node2_id} -> {node3_id} 关系不匹配"
555
+ assert edges_dict[(node3_id, node4_id)]["relationship"] == edge5_data["relationship"], f"边 {node3_id} -> {node4_id} 关系不匹配"
556
+
557
+ # 6. 测试 get_nodes_edges_batch - 批量获取多个节点的所有边
558
+ print("== 测试 get_nodes_edges_batch")
559
+ nodes_edges = await storage.get_nodes_edges_batch([node1_id, node3_id])
560
+ print(f"批量获取节点边结果: {nodes_edges.keys()}")
561
+ assert len(nodes_edges) == 2, f"应返回2个节点的边,实际返回 {len(nodes_edges)} 个"
562
+ assert node1_id in nodes_edges, f"{node1_id} 应在返回结果中"
563
+ assert node3_id in nodes_edges, f"{node3_id} 应在返回结果中"
564
+ assert len(nodes_edges[node1_id]) == 3, f"{node1_id} 应有3条边,实际有 {len(nodes_edges[node1_id])} 条"
565
+ assert len(nodes_edges[node3_id]) == 3, f"{node3_id} 应有3条边,实际有 {len(nodes_edges[node3_id])} 条"
566
+
567
+ # 7. 清理数据
568
+ print("== 测试 drop")
569
+ result = await storage.drop()
570
+ print(f"清理结果: {result}")
571
+ assert result["status"] == "success", f"清理应成功,实际状态为 {result['status']}"
572
+
573
+ print("\n批量操作测试完成")
574
+ return True
575
+
576
+ except Exception as e:
577
+ ASCIIColors.red(f"测试过程中发生错误: {str(e)}")
578
+ return False
579
+
580
+
581
  async def main():
582
  """主函数"""
583
  # 显示程序标题
 
612
  ASCIIColors.yellow("\n请选择测试类型:")
613
  ASCIIColors.white("1. 基本测试 (节点和边的插入、读取)")
614
  ASCIIColors.white("2. 高级测试 (度数、标签、知识图谱、删除操作等)")
615
+ ASCIIColors.white("3. 批量操作测试 (批量获取节点、边属性和度数等)")
616
+ ASCIIColors.white("4. 全部测试")
617
 
618
+ choice = input("\n请输入选项 (1/2/3/4): ")
619
 
620
  if choice == "1":
621
  await test_graph_basic(storage)
622
  elif choice == "2":
623
  await test_graph_advanced(storage)
624
  elif choice == "3":
625
+ await test_graph_batch_operations(storage)
626
+ elif choice == "4":
627
  ASCIIColors.cyan("\n=== 开始基本测试 ===")
628
  basic_result = await test_graph_basic(storage)
629
 
630
  if basic_result:
631
  ASCIIColors.cyan("\n=== 开始高级测试 ===")
632
+ advanced_result = await test_graph_advanced(storage)
633
+
634
+ if advanced_result:
635
+ ASCIIColors.cyan("\n=== 开始批量操作测试 ===")
636
+ await test_graph_batch_operations(storage)
637
  else:
638
  ASCIIColors.red("无效的选项")
639