yangdx commited on
Commit
a7b9344
·
1 Parent(s): 5cd6509

Fix linting

Browse files
Files changed (1) hide show
  1. tests/test_graph_storage.py +95 -66
tests/test_graph_storage.py CHANGED
@@ -27,14 +27,16 @@ from lightrag.kg import (
27
  STORAGE_IMPLEMENTATIONS,
28
  STORAGE_ENV_REQUIREMENTS,
29
  STORAGES,
30
- verify_storage_implementation
31
  )
32
  from lightrag.kg.shared_storage import initialize_share_data
33
 
 
34
  # 模拟的嵌入函数,返回随机向量
35
  async def mock_embedding_func(texts):
36
  return np.random.rand(len(texts), 10) # 返回10维随机向量
37
 
 
38
  def check_env_file():
39
  """
40
  检查.env文件是否存在,如果不存在则发出警告
@@ -52,6 +54,7 @@ def check_env_file():
52
  return False
53
  return True
54
 
 
55
  async def initialize_graph_storage():
56
  """
57
  根据环境变量初始化相应的图存储实例
@@ -59,56 +62,60 @@ async def initialize_graph_storage():
59
  """
60
  # 从环境变量中获取图存储类型
61
  graph_storage_type = os.getenv("LIGHTRAG_GRAPH_STORAGE", "NetworkXStorage")
62
-
63
  # 验证存储类型是否有效
64
  try:
65
  verify_storage_implementation("GRAPH_STORAGE", graph_storage_type)
66
  except ValueError as e:
67
  ASCIIColors.red(f"错误: {str(e)}")
68
- ASCIIColors.yellow(f"支持的图存储类型: {', '.join(STORAGE_IMPLEMENTATIONS['GRAPH_STORAGE']['implementations'])}")
 
 
69
  return None
70
-
71
  # 检查所需的环境变量
72
  required_env_vars = STORAGE_ENV_REQUIREMENTS.get(graph_storage_type, [])
73
  missing_env_vars = [var for var in required_env_vars if not os.getenv(var)]
74
-
75
  if missing_env_vars:
76
- ASCIIColors.red(f"错误: {graph_storage_type} 需要以下环境变量,但未设置: {', '.join(missing_env_vars)}")
 
 
77
  return None
78
-
79
  # 动态导入相应的模块
80
  module_path = STORAGES.get(graph_storage_type)
81
  if not module_path:
82
  ASCIIColors.red(f"错误: 未找到 {graph_storage_type} 的模块路径")
83
  return None
84
-
85
  try:
86
  module = importlib.import_module(module_path, package="lightrag")
87
  storage_class = getattr(module, graph_storage_type)
88
  except (ImportError, AttributeError) as e:
89
  ASCIIColors.red(f"错误: 导入 {graph_storage_type} 失败: {str(e)}")
90
  return None
91
-
92
  # 初始化存储实例
93
  global_config = {
94
  "embedding_batch_num": 10, # 批处理大小
95
  "vector_db_storage_cls_kwargs": {
96
  "cosine_better_than_threshold": 0.5 # 余弦相似度阈值
97
  },
98
- "working_dir": os.environ.get("WORKING_DIR", "./rag_storage") # 工作目录
99
  }
100
-
101
  # 如果使用 NetworkXStorage,需要先初始化 shared_storage
102
  if graph_storage_type == "NetworkXStorage":
103
  initialize_share_data() # 使用单进程模式
104
-
105
  try:
106
  storage = storage_class(
107
  namespace="test_graph",
108
  global_config=global_config,
109
- embedding_func=mock_embedding_func
110
  )
111
-
112
  # 初始化连接
113
  await storage.initialize()
114
  return storage
@@ -116,6 +123,7 @@ async def initialize_graph_storage():
116
  ASCIIColors.red(f"错误: 初始化 {graph_storage_type} 失败: {str(e)}")
117
  return None
118
 
 
119
  async def test_graph_basic(storage):
120
  """
121
  测试图数据库的基本操作:
@@ -128,38 +136,38 @@ async def test_graph_basic(storage):
128
  # 清理之前的测试数据
129
  print("清理之前的测试数据...")
130
  await storage.drop()
131
-
132
  # 1. 插入第一个节点
133
  node1_id = "人工智能"
134
  node1_data = {
135
  "entity_id": node1_id,
136
  "description": "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。",
137
  "keywords": "AI,机器学习,深度学习",
138
- "entity_type": "技术领域"
139
  }
140
  print(f"插入节点1: {node1_id}")
141
  await storage.upsert_node(node1_id, node1_data)
142
-
143
  # 2. 插入第二个节点
144
  node2_id = "机器学习"
145
  node2_data = {
146
  "entity_id": node2_id,
147
  "description": "机器学习是人工智能的一个分支,它使用统计学方法让计算机系统在不被明确编程的情况下也能够学习。",
148
  "keywords": "监督学习,无监督学习,强化学习",
149
- "entity_type": "技术领域"
150
  }
151
  print(f"插入节点2: {node2_id}")
152
  await storage.upsert_node(node2_id, node2_data)
153
-
154
  # 3. 插入连接边
155
  edge_data = {
156
  "relationship": "包含",
157
  "weight": 1.0,
158
- "description": "人工智能领域包含机器学习这个子领域"
159
  }
160
  print(f"插入边: {node1_id} -> {node2_id}")
161
  await storage.upsert_edge(node1_id, node2_id, edge_data)
162
-
163
  # 4. 读取节点属性
164
  print(f"读取节点属性: {node1_id}")
165
  node1_props = await storage.get_node(node1_id)
@@ -169,13 +177,19 @@ async def test_graph_basic(storage):
169
  print(f"节点类型: {node1_props.get('entity_type', '无类型')}")
170
  print(f"节点关键词: {node1_props.get('keywords', '无关键词')}")
171
  # 验证返回的属性是否正确
172
- assert node1_props.get('entity_id') == node1_id, f"节点ID不匹配: 期望 {node1_id}, 实际 {node1_props.get('entity_id')}"
173
- assert node1_props.get('description') == node1_data['description'], "节点描述不匹配"
174
- assert node1_props.get('entity_type') == node1_data['entity_type'], "节点类型不匹配"
 
 
 
 
 
 
175
  else:
176
  print(f"读取节点属性失败: {node1_id}")
177
  assert False, f"未能读取节点属性: {node1_id}"
178
-
179
  # 5. 读取边属性
180
  print(f"读取边属性: {node1_id} -> {node2_id}")
181
  edge_props = await storage.get_edge(node1_id, node2_id)
@@ -185,20 +199,25 @@ async def test_graph_basic(storage):
185
  print(f"边描述: {edge_props.get('description', '无描述')}")
186
  print(f"边权重: {edge_props.get('weight', '无权重')}")
187
  # 验证返回的属性是否正确
188
- assert edge_props.get('relationship') == edge_data['relationship'], "边关系不匹配"
189
- assert edge_props.get('description') == edge_data['description'], "边描述不匹配"
190
- assert edge_props.get('weight') == edge_data['weight'], "边权重不匹配"
 
 
 
 
191
  else:
192
  print(f"读取边属性失败: {node1_id} -> {node2_id}")
193
  assert False, f"未能读取边属性: {node1_id} -> {node2_id}"
194
-
195
  print("基本测试完成,数据已保留在数据库中")
196
  return True
197
-
198
  except Exception as e:
199
  ASCIIColors.red(f"测试过程中发生错误: {str(e)}")
200
  return False
201
 
 
202
  async def test_graph_advanced(storage):
203
  """
204
  测试图数据库的高级操作:
@@ -216,7 +235,7 @@ async def test_graph_advanced(storage):
216
  # 清理之前的测试数据
217
  print("清理之前的测试数据...\n")
218
  await storage.drop()
219
-
220
  # 1. 插入测试数据
221
  # 插入节点1: 人工智能
222
  node1_id = "人工智能"
@@ -224,69 +243,73 @@ async def test_graph_advanced(storage):
224
  "entity_id": node1_id,
225
  "description": "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。",
226
  "keywords": "AI,机器学习,深度学习",
227
- "entity_type": "技术领域"
228
  }
229
  print(f"插入节点1: {node1_id}")
230
  await storage.upsert_node(node1_id, node1_data)
231
-
232
  # 插入节点2: 机器学习
233
  node2_id = "机器学习"
234
  node2_data = {
235
  "entity_id": node2_id,
236
  "description": "机器学习是人工智能的一个分支,它使用统计学方法让计算机系统在不被明确编程的情况下也能够学习。",
237
  "keywords": "监督学习,无监督学习,强化学习",
238
- "entity_type": "技术领域"
239
  }
240
  print(f"插入节点2: {node2_id}")
241
  await storage.upsert_node(node2_id, node2_data)
242
-
243
  # 插入节点3: 深度学习
244
  node3_id = "深度学习"
245
  node3_data = {
246
  "entity_id": node3_id,
247
  "description": "深度学习是机器学习的一个分支,它使用多层神经网络来模拟人脑的学习过程。",
248
  "keywords": "神经网络,CNN,RNN",
249
- "entity_type": "技��领域"
250
  }
251
  print(f"插入节点3: {node3_id}")
252
  await storage.upsert_node(node3_id, node3_data)
253
-
254
  # 插入边1: 人工智能 -> 机器学习
255
  edge1_data = {
256
  "relationship": "包含",
257
  "weight": 1.0,
258
- "description": "人工智能领域包含机器学习这个子领域"
259
  }
260
  print(f"插入边1: {node1_id} -> {node2_id}")
261
  await storage.upsert_edge(node1_id, node2_id, edge1_data)
262
-
263
  # 插入边2: 机器学习 -> 深度学习
264
  edge2_data = {
265
  "relationship": "包含",
266
  "weight": 1.0,
267
- "description": "机器学习领域包含深度学习这个子领域"
268
  }
269
  print(f"插入边2: {node2_id} -> {node3_id}")
270
  await storage.upsert_edge(node2_id, node3_id, edge2_data)
271
-
272
  # 2. 测试 node_degree - 获取节点的度数
273
  print(f"== 测试 node_degree: {node1_id}")
274
  node1_degree = await storage.node_degree(node1_id)
275
  print(f"节点 {node1_id} 的度数: {node1_degree}")
276
  assert node1_degree == 1, f"节点 {node1_id} 的度数应为1,实际为 {node1_degree}"
277
-
278
  # 3. 测试 edge_degree - 获取边的度数
279
  print(f"== 测试 edge_degree: {node1_id} -> {node2_id}")
280
  edge_degree = await storage.edge_degree(node1_id, node2_id)
281
  print(f"边 {node1_id} -> {node2_id} 的度数: {edge_degree}")
282
- assert edge_degree == 3, f"边 {node1_id} -> {node2_id} 的度数应为2,实际为 {edge_degree}"
283
-
 
 
284
  # 4. 测试 get_node_edges - 获取节点的所有边
285
  print(f"== 测试 get_node_edges: {node2_id}")
286
  node2_edges = await storage.get_node_edges(node2_id)
287
  print(f"节点 {node2_id} 的所有边: {node2_edges}")
288
- assert len(node2_edges) == 2, f"节点 {node2_id} 应有2条边,实际有 {len(node2_edges)}"
289
-
 
 
290
  # 5. 测试 get_all_labels - 获取所有标签
291
  print("== 测试 get_all_labels")
292
  all_labels = await storage.get_all_labels()
@@ -295,7 +318,7 @@ async def test_graph_advanced(storage):
295
  assert node1_id in all_labels, f"{node1_id} 应在标签列表中"
296
  assert node2_id in all_labels, f"{node2_id} 应在标签列表中"
297
  assert node3_id in all_labels, f"{node3_id} 应在标签列表中"
298
-
299
  # 6. 测试 get_knowledge_graph - 获取知识图谱
300
  print("== 测试 get_knowledge_graph")
301
  kg = await storage.get_knowledge_graph("*", max_depth=2, max_nodes=10)
@@ -304,25 +327,25 @@ async def test_graph_advanced(storage):
304
  assert isinstance(kg, KnowledgeGraph), "返回结果应为 KnowledgeGraph 类型"
305
  assert len(kg.nodes) == 3, f"知识图谱应有3个节点,实际有 {len(kg.nodes)}"
306
  assert len(kg.edges) == 2, f"知识图谱应有2条边,实际有 {len(kg.edges)}"
307
-
308
  # 7. 测试 delete_node - 删除节点
309
  print(f"== 测试 delete_node: {node3_id}")
310
  await storage.delete_node(node3_id)
311
  node3_props = await storage.get_node(node3_id)
312
  print(f"删除后查询节点属性 {node3_id}: {node3_props}")
313
  assert node3_props is None, f"节点 {node3_id} 应已被删除"
314
-
315
  # 重新插入节点3用于后续测试
316
  await storage.upsert_node(node3_id, node3_data)
317
  await storage.upsert_edge(node2_id, node3_id, edge2_data)
318
-
319
  # 8. 测试 remove_edges - 删除边
320
  print(f"== 测试 remove_edges: {node2_id} -> {node3_id}")
321
  await storage.remove_edges([(node2_id, node3_id)])
322
  edge_props = await storage.get_edge(node2_id, node3_id)
323
  print(f"删除后查询边属性 {node2_id} -> {node3_id}: {edge_props}")
324
  assert edge_props is None, f"边 {node2_id} -> {node3_id} 应已被删除"
325
-
326
  # 9. 测试 remove_nodes - 批量删除节点
327
  print(f"== 测试 remove_nodes: [{node2_id}, {node3_id}]")
328
  await storage.remove_nodes([node2_id, node3_id])
@@ -332,25 +355,28 @@ async def test_graph_advanced(storage):
332
  print(f"删除后查询节点属性 {node3_id}: {node3_props}")
333
  assert node2_props is None, f"节点 {node2_id} 应已被删除"
334
  assert node3_props is None, f"节点 {node3_id} 应已被删除"
335
-
336
  # 10. 测试 drop - 清理数据
337
  print("== 测试 drop")
338
  result = await storage.drop()
339
  print(f"清理结果: {result}")
340
- assert result["status"] == "success", f"清理应成功,实际状态为 {result['status']}"
341
-
 
 
342
  # 验证清理结果
343
  all_labels = await storage.get_all_labels()
344
  print(f"清理后的所有标签: {all_labels}")
345
  assert len(all_labels) == 0, f"清理后应没有标签,实际有 {len(all_labels)}"
346
-
347
  print("\n高级测试完成")
348
  return True
349
-
350
  except Exception as e:
351
  ASCIIColors.red(f"测试过程中发生错误: {str(e)}")
352
  return False
353
 
 
354
  async def main():
355
  """主函数"""
356
  # 显示程序标题
@@ -359,34 +385,36 @@ async def main():
359
  ║ 通用图存储测试程序 ║
360
  ╚══════════════════════════════════════════════════════════════╝
361
  """)
362
-
363
  # 检查.env文件
364
  if not check_env_file():
365
  return
366
-
367
  # 加载环境变量
368
  load_dotenv(dotenv_path=".env", override=False)
369
-
370
  # 获取图存储类型
371
  graph_storage_type = os.getenv("LIGHTRAG_GRAPH_STORAGE", "NetworkXStorage")
372
  ASCIIColors.magenta(f"\n当前配置的图存储类型: {graph_storage_type}")
373
- ASCIIColors.white(f"支持的图存储类型: {', '.join(STORAGE_IMPLEMENTATIONS['GRAPH_STORAGE']['implementations'])}")
374
-
 
 
375
  # 初始化存储实例
376
  storage = await initialize_graph_storage()
377
  if not storage:
378
  ASCIIColors.red("初始化存储实例失败,测试程序退出")
379
  return
380
-
381
  try:
382
  # 显示测试选项
383
  ASCIIColors.yellow("\n请选择测试类型:")
384
  ASCIIColors.white("1. 基本测试 (节点和边的插入、读取)")
385
  ASCIIColors.white("2. 高级测试 (度数、标签、知识图谱、删除操作等)")
386
  ASCIIColors.white("3. 全部测试")
387
-
388
  choice = input("\n请输入选项 (1/2/3): ")
389
-
390
  if choice == "1":
391
  await test_graph_basic(storage)
392
  elif choice == "2":
@@ -394,18 +422,19 @@ async def main():
394
  elif choice == "3":
395
  ASCIIColors.cyan("\n=== 开始基本测试 ===")
396
  basic_result = await test_graph_basic(storage)
397
-
398
  if basic_result:
399
  ASCIIColors.cyan("\n=== 开始高级测试 ===")
400
  await test_graph_advanced(storage)
401
  else:
402
  ASCIIColors.red("无效的选项")
403
-
404
  finally:
405
  # 关闭连接
406
  if storage:
407
  await storage.finalize()
408
  ASCIIColors.green("\n存储连接已关闭")
409
 
 
410
  if __name__ == "__main__":
411
  asyncio.run(main())
 
27
  STORAGE_IMPLEMENTATIONS,
28
  STORAGE_ENV_REQUIREMENTS,
29
  STORAGES,
30
+ verify_storage_implementation,
31
  )
32
  from lightrag.kg.shared_storage import initialize_share_data
33
 
34
+
35
  # 模拟的嵌入函数,返回随机向量
36
  async def mock_embedding_func(texts):
37
  return np.random.rand(len(texts), 10) # 返回10维随机向量
38
 
39
+
40
  def check_env_file():
41
  """
42
  检查.env文件是否存在,如果不存在则发出警告
 
54
  return False
55
  return True
56
 
57
+
58
  async def initialize_graph_storage():
59
  """
60
  根据环境变量初始化相应的图存储实例
 
62
  """
63
  # 从环境变量中获取图存储类型
64
  graph_storage_type = os.getenv("LIGHTRAG_GRAPH_STORAGE", "NetworkXStorage")
65
+
66
  # 验证存储类型是否有效
67
  try:
68
  verify_storage_implementation("GRAPH_STORAGE", graph_storage_type)
69
  except ValueError as e:
70
  ASCIIColors.red(f"错误: {str(e)}")
71
+ ASCIIColors.yellow(
72
+ f"支持的图存储类型: {', '.join(STORAGE_IMPLEMENTATIONS['GRAPH_STORAGE']['implementations'])}"
73
+ )
74
  return None
75
+
76
  # 检查所需的环境变量
77
  required_env_vars = STORAGE_ENV_REQUIREMENTS.get(graph_storage_type, [])
78
  missing_env_vars = [var for var in required_env_vars if not os.getenv(var)]
79
+
80
  if missing_env_vars:
81
+ ASCIIColors.red(
82
+ f"错误: {graph_storage_type} 需要以下环境变量,但未设置: {', '.join(missing_env_vars)}"
83
+ )
84
  return None
85
+
86
  # 动态导入相应的模块
87
  module_path = STORAGES.get(graph_storage_type)
88
  if not module_path:
89
  ASCIIColors.red(f"错误: 未找到 {graph_storage_type} 的模块路径")
90
  return None
91
+
92
  try:
93
  module = importlib.import_module(module_path, package="lightrag")
94
  storage_class = getattr(module, graph_storage_type)
95
  except (ImportError, AttributeError) as e:
96
  ASCIIColors.red(f"错误: 导入 {graph_storage_type} 失败: {str(e)}")
97
  return None
98
+
99
  # 初始化存储实例
100
  global_config = {
101
  "embedding_batch_num": 10, # 批处理大小
102
  "vector_db_storage_cls_kwargs": {
103
  "cosine_better_than_threshold": 0.5 # 余弦相似度阈值
104
  },
105
+ "working_dir": os.environ.get("WORKING_DIR", "./rag_storage"), # 工作目录
106
  }
107
+
108
  # 如果使用 NetworkXStorage,需要先初始化 shared_storage
109
  if graph_storage_type == "NetworkXStorage":
110
  initialize_share_data() # 使用单进程模式
111
+
112
  try:
113
  storage = storage_class(
114
  namespace="test_graph",
115
  global_config=global_config,
116
+ embedding_func=mock_embedding_func,
117
  )
118
+
119
  # 初始化连接
120
  await storage.initialize()
121
  return storage
 
123
  ASCIIColors.red(f"错误: 初始化 {graph_storage_type} 失败: {str(e)}")
124
  return None
125
 
126
+
127
  async def test_graph_basic(storage):
128
  """
129
  测试图数据库的基本操作:
 
136
  # 清理之前的测试数据
137
  print("清理之前的测试数据...")
138
  await storage.drop()
139
+
140
  # 1. 插入第一个节点
141
  node1_id = "人工智能"
142
  node1_data = {
143
  "entity_id": node1_id,
144
  "description": "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。",
145
  "keywords": "AI,机器学习,深度学习",
146
+ "entity_type": "技术领域",
147
  }
148
  print(f"插入节点1: {node1_id}")
149
  await storage.upsert_node(node1_id, node1_data)
150
+
151
  # 2. 插入第二个节点
152
  node2_id = "机器学习"
153
  node2_data = {
154
  "entity_id": node2_id,
155
  "description": "机器学习是人工智能的一个分支,它使用统计学方法让计算机系统在不被明确编程的情况下也能够学习。",
156
  "keywords": "监督学习,无监督学习,强化学习",
157
+ "entity_type": "技术领域",
158
  }
159
  print(f"插入节点2: {node2_id}")
160
  await storage.upsert_node(node2_id, node2_data)
161
+
162
  # 3. 插入连接边
163
  edge_data = {
164
  "relationship": "包含",
165
  "weight": 1.0,
166
+ "description": "人工智能领域包含机器学习这个子领域",
167
  }
168
  print(f"插入边: {node1_id} -> {node2_id}")
169
  await storage.upsert_edge(node1_id, node2_id, edge_data)
170
+
171
  # 4. 读取节点属性
172
  print(f"读取节点属性: {node1_id}")
173
  node1_props = await storage.get_node(node1_id)
 
177
  print(f"节点类型: {node1_props.get('entity_type', '无类型')}")
178
  print(f"节点关键词: {node1_props.get('keywords', '无关键词')}")
179
  # 验证返回的属性是否正确
180
+ assert (
181
+ node1_props.get("entity_id") == node1_id
182
+ ), f"节点ID不匹配: 期望 {node1_id}, 实际 {node1_props.get('entity_id')}"
183
+ assert (
184
+ node1_props.get("description") == node1_data["description"]
185
+ ), "节点描述不匹配"
186
+ assert (
187
+ node1_props.get("entity_type") == node1_data["entity_type"]
188
+ ), "节点类型不匹配"
189
  else:
190
  print(f"读取节点属性失败: {node1_id}")
191
  assert False, f"未能读取节点属性: {node1_id}"
192
+
193
  # 5. 读取边属性
194
  print(f"读取边属性: {node1_id} -> {node2_id}")
195
  edge_props = await storage.get_edge(node1_id, node2_id)
 
199
  print(f"边描述: {edge_props.get('description', '无描述')}")
200
  print(f"边权重: {edge_props.get('weight', '无权重')}")
201
  # 验证返回的属性是否正确
202
+ assert (
203
+ edge_props.get("relationship") == edge_data["relationship"]
204
+ ), "边关系不匹配"
205
+ assert (
206
+ edge_props.get("description") == edge_data["description"]
207
+ ), "边描述不匹配"
208
+ assert edge_props.get("weight") == edge_data["weight"], "边权重不匹配"
209
  else:
210
  print(f"读取边属性失败: {node1_id} -> {node2_id}")
211
  assert False, f"未能读取边属性: {node1_id} -> {node2_id}"
212
+
213
  print("基本测试完成,数据已保留在数据库中")
214
  return True
215
+
216
  except Exception as e:
217
  ASCIIColors.red(f"测试过程中发生错误: {str(e)}")
218
  return False
219
 
220
+
221
  async def test_graph_advanced(storage):
222
  """
223
  测试图数据库的高级操作:
 
235
  # 清理之前的测试数据
236
  print("清理之前的测试数据...\n")
237
  await storage.drop()
238
+
239
  # 1. 插入测试数据
240
  # 插入节点1: 人工智能
241
  node1_id = "人工智能"
 
243
  "entity_id": node1_id,
244
  "description": "人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。",
245
  "keywords": "AI,机器学习,深度学习",
246
+ "entity_type": "技术领域",
247
  }
248
  print(f"插入节点1: {node1_id}")
249
  await storage.upsert_node(node1_id, node1_data)
250
+
251
  # 插入节点2: 机器学习
252
  node2_id = "机器学习"
253
  node2_data = {
254
  "entity_id": node2_id,
255
  "description": "机器学习是人工智能的一个分支,它使用统计学方法让计算机系统在不被明确编程的情况下也能够学习。",
256
  "keywords": "监督学习,无监督学习,强化学习",
257
+ "entity_type": "技术领域",
258
  }
259
  print(f"插入节点2: {node2_id}")
260
  await storage.upsert_node(node2_id, node2_data)
261
+
262
  # 插入节点3: 深度学习
263
  node3_id = "深度学习"
264
  node3_data = {
265
  "entity_id": node3_id,
266
  "description": "深度学习是机器学习的一个分支,它使用多层神经网络来模拟人脑的学习过程。",
267
  "keywords": "神经网络,CNN,RNN",
268
+ "entity_type": "技术领域",
269
  }
270
  print(f"插入节点3: {node3_id}")
271
  await storage.upsert_node(node3_id, node3_data)
272
+
273
  # 插入边1: 人工智能 -> 机器学习
274
  edge1_data = {
275
  "relationship": "包含",
276
  "weight": 1.0,
277
+ "description": "人工智能领域包含机器学习这个子领域",
278
  }
279
  print(f"插入边1: {node1_id} -> {node2_id}")
280
  await storage.upsert_edge(node1_id, node2_id, edge1_data)
281
+
282
  # 插入边2: 机器学习 -> 深度学习
283
  edge2_data = {
284
  "relationship": "包含",
285
  "weight": 1.0,
286
+ "description": "机器学习领域包含深度学习这个子领域",
287
  }
288
  print(f"插入边2: {node2_id} -> {node3_id}")
289
  await storage.upsert_edge(node2_id, node3_id, edge2_data)
290
+
291
  # 2. 测试 node_degree - 获取节点的度数
292
  print(f"== 测试 node_degree: {node1_id}")
293
  node1_degree = await storage.node_degree(node1_id)
294
  print(f"节点 {node1_id} 的度数: {node1_degree}")
295
  assert node1_degree == 1, f"节点 {node1_id} 的度数应为1,实际为 {node1_degree}"
296
+
297
  # 3. 测试 edge_degree - 获取边的度数
298
  print(f"== 测试 edge_degree: {node1_id} -> {node2_id}")
299
  edge_degree = await storage.edge_degree(node1_id, node2_id)
300
  print(f"边 {node1_id} -> {node2_id} 的度数: {edge_degree}")
301
+ assert (
302
+ edge_degree == 3
303
+ ), f"边 {node1_id} -> {node2_id} 的度数应为2,实际为 {edge_degree}"
304
+
305
  # 4. 测试 get_node_edges - 获取节点的所有边
306
  print(f"== 测试 get_node_edges: {node2_id}")
307
  node2_edges = await storage.get_node_edges(node2_id)
308
  print(f"节点 {node2_id} 的所有边: {node2_edges}")
309
+ assert (
310
+ len(node2_edges) == 2
311
+ ), f"节点 {node2_id} 应有2条边,实际有 {len(node2_edges)}"
312
+
313
  # 5. 测试 get_all_labels - 获取所有标签
314
  print("== 测试 get_all_labels")
315
  all_labels = await storage.get_all_labels()
 
318
  assert node1_id in all_labels, f"{node1_id} 应在标签列表中"
319
  assert node2_id in all_labels, f"{node2_id} 应在标签列表中"
320
  assert node3_id in all_labels, f"{node3_id} 应在标签列表中"
321
+
322
  # 6. 测试 get_knowledge_graph - 获取知识图谱
323
  print("== 测试 get_knowledge_graph")
324
  kg = await storage.get_knowledge_graph("*", max_depth=2, max_nodes=10)
 
327
  assert isinstance(kg, KnowledgeGraph), "返回结果应为 KnowledgeGraph 类型"
328
  assert len(kg.nodes) == 3, f"知识图谱应有3个节点,实际有 {len(kg.nodes)}"
329
  assert len(kg.edges) == 2, f"知识图谱应有2条边,实际有 {len(kg.edges)}"
330
+
331
  # 7. 测试 delete_node - 删除节点
332
  print(f"== 测试 delete_node: {node3_id}")
333
  await storage.delete_node(node3_id)
334
  node3_props = await storage.get_node(node3_id)
335
  print(f"删除后查询节点属性 {node3_id}: {node3_props}")
336
  assert node3_props is None, f"节点 {node3_id} 应已被删除"
337
+
338
  # 重新插入节点3用于后续测试
339
  await storage.upsert_node(node3_id, node3_data)
340
  await storage.upsert_edge(node2_id, node3_id, edge2_data)
341
+
342
  # 8. 测试 remove_edges - 删除边
343
  print(f"== 测试 remove_edges: {node2_id} -> {node3_id}")
344
  await storage.remove_edges([(node2_id, node3_id)])
345
  edge_props = await storage.get_edge(node2_id, node3_id)
346
  print(f"删除后查询边属性 {node2_id} -> {node3_id}: {edge_props}")
347
  assert edge_props is None, f"边 {node2_id} -> {node3_id} 应已被删除"
348
+
349
  # 9. 测试 remove_nodes - 批量删除节点
350
  print(f"== 测试 remove_nodes: [{node2_id}, {node3_id}]")
351
  await storage.remove_nodes([node2_id, node3_id])
 
355
  print(f"删除后查询节点属性 {node3_id}: {node3_props}")
356
  assert node2_props is None, f"节点 {node2_id} 应已被删除"
357
  assert node3_props is None, f"节点 {node3_id} 应已被删除"
358
+
359
  # 10. 测试 drop - 清理数据
360
  print("== 测试 drop")
361
  result = await storage.drop()
362
  print(f"清理结果: {result}")
363
+ assert (
364
+ result["status"] == "success"
365
+ ), f"清理应成功,实际状态为 {result['status']}"
366
+
367
  # 验证清理结果
368
  all_labels = await storage.get_all_labels()
369
  print(f"清理后的所有标签: {all_labels}")
370
  assert len(all_labels) == 0, f"清理后应没有标签,实际有 {len(all_labels)}"
371
+
372
  print("\n高级测试完成")
373
  return True
374
+
375
  except Exception as e:
376
  ASCIIColors.red(f"测试过程中发生错误: {str(e)}")
377
  return False
378
 
379
+
380
  async def main():
381
  """主函数"""
382
  # 显示程序标题
 
385
  ║ 通用图存储测试程序 ║
386
  ╚══════════════════════════════════════════════════════════════╝
387
  """)
388
+
389
  # 检查.env文件
390
  if not check_env_file():
391
  return
392
+
393
  # 加载环境变量
394
  load_dotenv(dotenv_path=".env", override=False)
395
+
396
  # 获取图存储类型
397
  graph_storage_type = os.getenv("LIGHTRAG_GRAPH_STORAGE", "NetworkXStorage")
398
  ASCIIColors.magenta(f"\n当前配置的图存储类型: {graph_storage_type}")
399
+ ASCIIColors.white(
400
+ f"支持的图存储类型: {', '.join(STORAGE_IMPLEMENTATIONS['GRAPH_STORAGE']['implementations'])}"
401
+ )
402
+
403
  # 初始化存储实例
404
  storage = await initialize_graph_storage()
405
  if not storage:
406
  ASCIIColors.red("初始化存储实例失败,测试程序退出")
407
  return
408
+
409
  try:
410
  # 显示测试选项
411
  ASCIIColors.yellow("\n请选择测试类型:")
412
  ASCIIColors.white("1. 基本测试 (节点和边的插入、读取)")
413
  ASCIIColors.white("2. 高级测试 (度数、标签、知识图谱、删除操作等)")
414
  ASCIIColors.white("3. 全部测试")
415
+
416
  choice = input("\n请输入选项 (1/2/3): ")
417
+
418
  if choice == "1":
419
  await test_graph_basic(storage)
420
  elif choice == "2":
 
422
  elif choice == "3":
423
  ASCIIColors.cyan("\n=== 开始基本测试 ===")
424
  basic_result = await test_graph_basic(storage)
425
+
426
  if basic_result:
427
  ASCIIColors.cyan("\n=== 开始高级测试 ===")
428
  await test_graph_advanced(storage)
429
  else:
430
  ASCIIColors.red("无效的选项")
431
+
432
  finally:
433
  # 关闭连接
434
  if storage:
435
  await storage.finalize()
436
  ASCIIColors.green("\n存储连接已关闭")
437
 
438
+
439
  if __name__ == "__main__":
440
  asyncio.run(main())