yangdx commited on
Commit
5cd6509
·
1 Parent(s): 0451ee6

Add graph storage unit test

Browse files
Files changed (1) hide show
  1. tests/test_graph_storage.py +411 -0
tests/test_graph_storage.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """
3
+ 通用图存储测试程序
4
+
5
+ 该程序根据.env中的LIGHTRAG_GRAPH_STORAGE配置选择使用的图存储类型,
6
+ 并对其进行基本操作和高级操作的测试。
7
+
8
+ 支持的图存储类型包括:
9
+ - NetworkXStorage
10
+ - Neo4JStorage
11
+ - PGGraphStorage
12
+ """
13
+
14
+ import asyncio
15
+ import os
16
+ import sys
17
+ import importlib
18
+ import numpy as np
19
+ from dotenv import load_dotenv
20
+ from ascii_colors import ASCIIColors
21
+
22
+ # 添加项目根目录到Python路径
23
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
24
+
25
+ from lightrag.types import KnowledgeGraph
26
+ from lightrag.kg import (
27
+ STORAGE_IMPLEMENTATIONS,
28
+ STORAGE_ENV_REQUIREMENTS,
29
+ STORAGES,
30
+ verify_storage_implementation
31
+ )
32
+ from lightrag.kg.shared_storage import initialize_share_data
33
+
34
+ # 模拟的嵌入函数,返回随机向量
35
+ async def mock_embedding_func(texts):
36
+ return np.random.rand(len(texts), 10) # 返回10维随机向量
37
+
38
+ def check_env_file():
39
+ """
40
+ 检查.env文件是否存在,如果不存在则发出警告
41
+ 返回True表示应该继续执行,False表示应该退出
42
+ """
43
+ if not os.path.exists(".env"):
44
+ warning_msg = "警告: 当前目录中没有找到.env文件,这可能会影响存储配置的加载。"
45
+ ASCIIColors.yellow(warning_msg)
46
+
47
+ # 检查是否在交互式终端中运行
48
+ if sys.stdin.isatty():
49
+ response = input("是否继续执行? (yes/no): ")
50
+ if response.lower() != "yes":
51
+ ASCIIColors.red("测试程序已取消")
52
+ return False
53
+ return True
54
+
55
+ async def initialize_graph_storage():
56
+ """
57
+ 根据环境变量初始化相应的图存储实例
58
+ 返回初始化的存储实例
59
+ """
60
+ # 从环境变量中获取图存储类型
61
+ graph_storage_type = os.getenv("LIGHTRAG_GRAPH_STORAGE", "NetworkXStorage")
62
+
63
+ # 验证存储类型是否有效
64
+ try:
65
+ verify_storage_implementation("GRAPH_STORAGE", graph_storage_type)
66
+ except ValueError as e:
67
+ ASCIIColors.red(f"错误: {str(e)}")
68
+ ASCIIColors.yellow(f"支持的图存储类型: {', '.join(STORAGE_IMPLEMENTATIONS['GRAPH_STORAGE']['implementations'])}")
69
+ return None
70
+
71
+ # 检查所需的环境变量
72
+ required_env_vars = STORAGE_ENV_REQUIREMENTS.get(graph_storage_type, [])
73
+ missing_env_vars = [var for var in required_env_vars if not os.getenv(var)]
74
+
75
+ if missing_env_vars:
76
+ ASCIIColors.red(f"错误: {graph_storage_type} 需要以下环境变量,但未设置: {', '.join(missing_env_vars)}")
77
+ return None
78
+
79
+ # 动态导入相应的模块
80
+ module_path = STORAGES.get(graph_storage_type)
81
+ if not module_path:
82
+ ASCIIColors.red(f"错误: 未找到 {graph_storage_type} 的模块路径")
83
+ return None
84
+
85
+ try:
86
+ module = importlib.import_module(module_path, package="lightrag")
87
+ storage_class = getattr(module, graph_storage_type)
88
+ except (ImportError, AttributeError) as e:
89
+ ASCIIColors.red(f"错误: 导入 {graph_storage_type} 失败: {str(e)}")
90
+ return None
91
+
92
+ # 初始化存储实例
93
+ global_config = {
94
+ "embedding_batch_num": 10, # 批处理大小
95
+ "vector_db_storage_cls_kwargs": {
96
+ "cosine_better_than_threshold": 0.5 # 余弦相似度阈值
97
+ },
98
+ "working_dir": os.environ.get("WORKING_DIR", "./rag_storage") # 工作目录
99
+ }
100
+
101
+ # 如果使用 NetworkXStorage,需要先初始化 shared_storage
102
+ if graph_storage_type == "NetworkXStorage":
103
+ initialize_share_data() # 使用单进程模式
104
+
105
+ try:
106
+ storage = storage_class(
107
+ namespace="test_graph",
108
+ global_config=global_config,
109
+ embedding_func=mock_embedding_func
110
+ )
111
+
112
+ # 初始化连接
113
+ await storage.initialize()
114
+ return storage
115
+ except Exception as e:
116
+ ASCIIColors.red(f"错误: 初始化 {graph_storage_type} 失败: {str(e)}")
117
+ return None
118
+
119
+ async def test_graph_basic(storage):
120
+ """
121
+ 测试图数据库的基本操作:
122
+ 1. 使用 upsert_node 插入两个节点
123
+ 2. 使用 upsert_edge 插入一条连接两个节点的边
124
+ 3. 使用 get_node 读取一个节点
125
+ 4. 使用 get_edge 读取一条边
126
+ """
127
+ try:
128
+ # 清理之前的测试数据
129
+ print("清理之前的测试数据...")
130
+ await storage.drop()
131
+
132
+ # 1. 插入第一个节点
133
+ node1_id = "人工智能"
134
+ node1_data = {
135
+ "entity_id": node1_id,
136
+ "description": "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。",
137
+ "keywords": "AI,机器学习,深度学习",
138
+ "entity_type": "技术领域"
139
+ }
140
+ print(f"插入节点1: {node1_id}")
141
+ await storage.upsert_node(node1_id, node1_data)
142
+
143
+ # 2. 插入第二个节点
144
+ node2_id = "机器学习"
145
+ node2_data = {
146
+ "entity_id": node2_id,
147
+ "description": "机器学习是人工智能的一个分支,它使用统计学方法让计算机系统在不被明确编程的情况下也能够学习。",
148
+ "keywords": "监督学习,无监督学习,强化学习",
149
+ "entity_type": "技术领域"
150
+ }
151
+ print(f"插入节点2: {node2_id}")
152
+ await storage.upsert_node(node2_id, node2_data)
153
+
154
+ # 3. 插入连接边
155
+ edge_data = {
156
+ "relationship": "包含",
157
+ "weight": 1.0,
158
+ "description": "人工智能领域包含机器学习这个子领域"
159
+ }
160
+ print(f"插入边: {node1_id} -> {node2_id}")
161
+ await storage.upsert_edge(node1_id, node2_id, edge_data)
162
+
163
+ # 4. 读取节点属性
164
+ print(f"读取节点属性: {node1_id}")
165
+ node1_props = await storage.get_node(node1_id)
166
+ if node1_props:
167
+ print(f"成功读取节点属性: {node1_id}")
168
+ print(f"节点描述: {node1_props.get('description', '无描述')}")
169
+ print(f"节点类型: {node1_props.get('entity_type', '无类型')}")
170
+ print(f"节点关键词: {node1_props.get('keywords', '无关键词')}")
171
+ # 验证返回的属性是否正确
172
+ assert node1_props.get('entity_id') == node1_id, f"节点ID不匹配: 期望 {node1_id}, 实际 {node1_props.get('entity_id')}"
173
+ assert node1_props.get('description') == node1_data['description'], "节点描述不匹配"
174
+ assert node1_props.get('entity_type') == node1_data['entity_type'], "节点类型不匹配"
175
+ else:
176
+ print(f"读取节点属性失败: {node1_id}")
177
+ assert False, f"未能读取节点属性: {node1_id}"
178
+
179
+ # 5. 读取边属性
180
+ print(f"读取边属性: {node1_id} -> {node2_id}")
181
+ edge_props = await storage.get_edge(node1_id, node2_id)
182
+ if edge_props:
183
+ print(f"成功读取边属性: {node1_id} -> {node2_id}")
184
+ print(f"边关系: {edge_props.get('relationship', '无关系')}")
185
+ print(f"边描述: {edge_props.get('description', '无描述')}")
186
+ print(f"边权重: {edge_props.get('weight', '无权重')}")
187
+ # 验证返回的属性是否正确
188
+ assert edge_props.get('relationship') == edge_data['relationship'], "边关系不匹配"
189
+ assert edge_props.get('description') == edge_data['description'], "边描述不匹配"
190
+ assert edge_props.get('weight') == edge_data['weight'], "边权重不匹配"
191
+ else:
192
+ print(f"读取边属性失败: {node1_id} -> {node2_id}")
193
+ assert False, f"未能读取边属性: {node1_id} -> {node2_id}"
194
+
195
+ print("基本测试完成,数据已保留在数据库中")
196
+ return True
197
+
198
+ except Exception as e:
199
+ ASCIIColors.red(f"测试过程中发生错误: {str(e)}")
200
+ return False
201
+
202
+ async def test_graph_advanced(storage):
203
+ """
204
+ 测试图数据库的高级操作:
205
+ 1. 使用 node_degree 获取节点的度数
206
+ 2. 使用 edge_degree 获取边的度数
207
+ 3. 使用 get_node_edges 获取节点的所有边
208
+ 4. 使用 get_all_labels 获取所有标签
209
+ 5. 使用 get_knowledge_graph 获取知识图谱
210
+ 6. 使用 delete_node 删除节点
211
+ 7. 使用 remove_nodes 批量删除节点
212
+ 8. 使用 remove_edges 删除边
213
+ 9. 使用 drop 清理数据
214
+ """
215
+ try:
216
+ # 清理之前的测试数据
217
+ print("清理之前的测试数据...\n")
218
+ await storage.drop()
219
+
220
+ # 1. 插入测试数据
221
+ # 插入节点1: 人工智能
222
+ node1_id = "人工智能"
223
+ node1_data = {
224
+ "entity_id": node1_id,
225
+ "description": "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。",
226
+ "keywords": "AI,机器学习,深度学习",
227
+ "entity_type": "技术领域"
228
+ }
229
+ print(f"插入节点1: {node1_id}")
230
+ await storage.upsert_node(node1_id, node1_data)
231
+
232
+ # 插入节点2: 机器学习
233
+ node2_id = "机器学习"
234
+ node2_data = {
235
+ "entity_id": node2_id,
236
+ "description": "机器学习是人工智能的一个分支,它使用统计学方法让计算机系统在不被明确编程的情况下也能够学习。",
237
+ "keywords": "监督学习,无监督学习,强化学习",
238
+ "entity_type": "技术领域"
239
+ }
240
+ print(f"插入节点2: {node2_id}")
241
+ await storage.upsert_node(node2_id, node2_data)
242
+
243
+ # 插入节点3: 深度学习
244
+ node3_id = "深度学习"
245
+ node3_data = {
246
+ "entity_id": node3_id,
247
+ "description": "深度学习是机器学习的一个分支,它使用多层神经网络来模拟人脑的学习过程。",
248
+ "keywords": "神经网络,CNN,RNN",
249
+ "entity_type": "技术领域"
250
+ }
251
+ print(f"插入节点3: {node3_id}")
252
+ await storage.upsert_node(node3_id, node3_data)
253
+
254
+ # 插入边1: 人工智能 -> 机器学习
255
+ edge1_data = {
256
+ "relationship": "包含",
257
+ "weight": 1.0,
258
+ "description": "人工智能领域包含机器学习这个子领域"
259
+ }
260
+ print(f"插入边1: {node1_id} -> {node2_id}")
261
+ await storage.upsert_edge(node1_id, node2_id, edge1_data)
262
+
263
+ # 插入边2: 机器学习 -> 深度学习
264
+ edge2_data = {
265
+ "relationship": "包含",
266
+ "weight": 1.0,
267
+ "description": "机器学习领域包含深度学习这个子领域"
268
+ }
269
+ print(f"插入边2: {node2_id} -> {node3_id}")
270
+ await storage.upsert_edge(node2_id, node3_id, edge2_data)
271
+
272
+ # 2. 测试 node_degree - 获取节点的度数
273
+ print(f"== 测试 node_degree: {node1_id}")
274
+ node1_degree = await storage.node_degree(node1_id)
275
+ print(f"节点 {node1_id} 的度数: {node1_degree}")
276
+ assert node1_degree == 1, f"节点 {node1_id} 的度数应为1,实际为 {node1_degree}"
277
+
278
+ # 3. 测试 edge_degree - 获取边的度数
279
+ print(f"== 测试 edge_degree: {node1_id} -> {node2_id}")
280
+ edge_degree = await storage.edge_degree(node1_id, node2_id)
281
+ print(f"边 {node1_id} -> {node2_id} 的度数: {edge_degree}")
282
+ assert edge_degree == 3, f"边 {node1_id} -> {node2_id} 的度数应为2,实际为 {edge_degree}"
283
+
284
+ # 4. 测试 get_node_edges - 获取节点的所有边
285
+ print(f"== 测试 get_node_edges: {node2_id}")
286
+ node2_edges = await storage.get_node_edges(node2_id)
287
+ print(f"节点 {node2_id} 的所有边: {node2_edges}")
288
+ assert len(node2_edges) == 2, f"节点 {node2_id} 应有2条边,实际有 {len(node2_edges)}"
289
+
290
+ # 5. 测试 get_all_labels - 获取所有标签
291
+ print("== 测试 get_all_labels")
292
+ all_labels = await storage.get_all_labels()
293
+ print(f"所有标签: {all_labels}")
294
+ assert len(all_labels) == 3, f"应有3个标签,实际有 {len(all_labels)}"
295
+ assert node1_id in all_labels, f"{node1_id} 应在标签列表中"
296
+ assert node2_id in all_labels, f"{node2_id} 应在标签列表中"
297
+ assert node3_id in all_labels, f"{node3_id} 应在标签列表中"
298
+
299
+ # 6. 测试 get_knowledge_graph - 获取知识图谱
300
+ print("== 测试 get_knowledge_graph")
301
+ kg = await storage.get_knowledge_graph("*", max_depth=2, max_nodes=10)
302
+ print(f"知识图谱节点数: {len(kg.nodes)}")
303
+ print(f"知识图谱边数: {len(kg.edges)}")
304
+ assert isinstance(kg, KnowledgeGraph), "返回结果应为 KnowledgeGraph 类型"
305
+ assert len(kg.nodes) == 3, f"知识图谱应有3个节点,实际有 {len(kg.nodes)}"
306
+ assert len(kg.edges) == 2, f"知识图谱应有2条边,实际有 {len(kg.edges)}"
307
+
308
+ # 7. 测试 delete_node - 删除节点
309
+ print(f"== 测试 delete_node: {node3_id}")
310
+ await storage.delete_node(node3_id)
311
+ node3_props = await storage.get_node(node3_id)
312
+ print(f"删除后查询节点属性 {node3_id}: {node3_props}")
313
+ assert node3_props is None, f"节点 {node3_id} 应已被删除"
314
+
315
+ # 重新插入节点3用于后续测试
316
+ await storage.upsert_node(node3_id, node3_data)
317
+ await storage.upsert_edge(node2_id, node3_id, edge2_data)
318
+
319
+ # 8. 测试 remove_edges - 删除边
320
+ print(f"== 测试 remove_edges: {node2_id} -> {node3_id}")
321
+ await storage.remove_edges([(node2_id, node3_id)])
322
+ edge_props = await storage.get_edge(node2_id, node3_id)
323
+ print(f"删除后查询边属性 {node2_id} -> {node3_id}: {edge_props}")
324
+ assert edge_props is None, f"边 {node2_id} -> {node3_id} 应已被删除"
325
+
326
+ # 9. 测试 remove_nodes - 批量删除节点
327
+ print(f"== 测试 remove_nodes: [{node2_id}, {node3_id}]")
328
+ await storage.remove_nodes([node2_id, node3_id])
329
+ node2_props = await storage.get_node(node2_id)
330
+ node3_props = await storage.get_node(node3_id)
331
+ print(f"删除后查询节点属性 {node2_id}: {node2_props}")
332
+ print(f"删除后查询节点属性 {node3_id}: {node3_props}")
333
+ assert node2_props is None, f"节点 {node2_id} 应已被删除"
334
+ assert node3_props is None, f"节点 {node3_id} 应已被删除"
335
+
336
+ # 10. 测试 drop - 清理数据
337
+ print("== 测试 drop")
338
+ result = await storage.drop()
339
+ print(f"清理结果: {result}")
340
+ assert result["status"] == "success", f"清理应成功,实际状态为 {result['status']}"
341
+
342
+ # 验证清理结果
343
+ all_labels = await storage.get_all_labels()
344
+ print(f"清理后的所有标签: {all_labels}")
345
+ assert len(all_labels) == 0, f"清理后应没有标签,实际有 {len(all_labels)}"
346
+
347
+ print("\n高级测试完成")
348
+ return True
349
+
350
+ except Exception as e:
351
+ ASCIIColors.red(f"测试过程中发生错误: {str(e)}")
352
+ return False
353
+
354
+ async def main():
355
+ """主函数"""
356
+ # 显示程序标题
357
+ ASCIIColors.cyan("""
358
+ ╔══════════════════════════════════════════════════════════════╗
359
+ ║ 通用图存储测试程序 ║
360
+ ╚══════════════════════════════════════════════════════════════╝
361
+ """)
362
+
363
+ # 检查.env文件
364
+ if not check_env_file():
365
+ return
366
+
367
+ # 加载环境变量
368
+ load_dotenv(dotenv_path=".env", override=False)
369
+
370
+ # 获取图存储类型
371
+ graph_storage_type = os.getenv("LIGHTRAG_GRAPH_STORAGE", "NetworkXStorage")
372
+ ASCIIColors.magenta(f"\n当前配置的图存储类型: {graph_storage_type}")
373
+ ASCIIColors.white(f"支持的图存储类型: {', '.join(STORAGE_IMPLEMENTATIONS['GRAPH_STORAGE']['implementations'])}")
374
+
375
+ # 初始化存储实例
376
+ storage = await initialize_graph_storage()
377
+ if not storage:
378
+ ASCIIColors.red("初始化存储实例失败,测试程序退出")
379
+ return
380
+
381
+ try:
382
+ # 显示测试选项
383
+ ASCIIColors.yellow("\n请选择测试类型:")
384
+ ASCIIColors.white("1. 基本测试 (节点和边的插入、读取)")
385
+ ASCIIColors.white("2. 高级测试 (度数、标签、知识图谱、删除操作等)")
386
+ ASCIIColors.white("3. 全部测试")
387
+
388
+ choice = input("\n请输入选项 (1/2/3): ")
389
+
390
+ if choice == "1":
391
+ await test_graph_basic(storage)
392
+ elif choice == "2":
393
+ await test_graph_advanced(storage)
394
+ elif choice == "3":
395
+ ASCIIColors.cyan("\n=== 开始基本测试 ===")
396
+ basic_result = await test_graph_basic(storage)
397
+
398
+ if basic_result:
399
+ ASCIIColors.cyan("\n=== 开始高级测试 ===")
400
+ await test_graph_advanced(storage)
401
+ else:
402
+ ASCIIColors.red("无效的选项")
403
+
404
+ finally:
405
+ # 关闭连接
406
+ if storage:
407
+ await storage.finalize()
408
+ ASCIIColors.green("\n存储连接已关闭")
409
+
410
+ if __name__ == "__main__":
411
+ asyncio.run(main())