yangdx
committed on
Commit
·
83b5082
1
Parent(s):
16e7647
Fix linting
Browse files
- lightrag/base.py +14 -8
- lightrag/kg/neo4j_impl.py +43 -23
- lightrag/kg/postgres_impl.py +70 -57
- lightrag/operate.py +17 -14
- lightrag_webui/src/stores/graph.ts +22 -22
- lightrag_webui/src/utils/graphOperations.ts +1 -1
- tests/test_graph_storage.py +82 -25
lightrag/base.py
CHANGED
@@ -363,7 +363,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|
363 |
|
364 |
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
365 |
"""Get nodes as a batch using UNWIND
|
366 |
-
|
367 |
Default implementation fetches nodes one by one.
|
368 |
Override this method for better performance in storage backends
|
369 |
that support batch operations.
|
@@ -377,7 +377,7 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|
377 |
|
378 |
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
379 |
"""Node degrees as a batch using UNWIND
|
380 |
-
|
381 |
Default implementation fetches node degrees one by one.
|
382 |
Override this method for better performance in storage backends
|
383 |
that support batch operations.
|
@@ -388,9 +388,11 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|
388 |
result[node_id] = degree
|
389 |
return result
|
390 |
|
391 |
-
async def edge_degrees_batch(
|
|
|
|
|
392 |
"""Edge degrees as a batch using UNWIND also uses node_degrees_batch
|
393 |
-
|
394 |
Default implementation calculates edge degrees one by one.
|
395 |
Override this method for better performance in storage backends
|
396 |
that support batch operations.
|
@@ -401,9 +403,11 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|
401 |
result[(src_id, tgt_id)] = degree
|
402 |
return result
|
403 |
|
404 |
-
async def get_edges_batch(
|
|
|
|
|
405 |
"""Get edges as a batch using UNWIND
|
406 |
-
|
407 |
Default implementation fetches edges one by one.
|
408 |
Override this method for better performance in storage backends
|
409 |
that support batch operations.
|
@@ -417,9 +421,11 @@ class BaseGraphStorage(StorageNameSpace, ABC):
|
|
417 |
result[(src_id, tgt_id)] = edge
|
418 |
return result
|
419 |
|
420 |
-
async def get_nodes_edges_batch(
|
|
|
|
|
421 |
"""Get nodes edges as a batch using UNWIND
|
422 |
-
|
423 |
Default implementation fetches node edges one by one.
|
424 |
Override this method for better performance in storage backends
|
425 |
that support batch operations.
|
|
|
363 |
|
364 |
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
365 |
"""Get nodes as a batch using UNWIND
|
366 |
+
|
367 |
Default implementation fetches nodes one by one.
|
368 |
Override this method for better performance in storage backends
|
369 |
that support batch operations.
|
|
|
377 |
|
378 |
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
379 |
"""Node degrees as a batch using UNWIND
|
380 |
+
|
381 |
Default implementation fetches node degrees one by one.
|
382 |
Override this method for better performance in storage backends
|
383 |
that support batch operations.
|
|
|
388 |
result[node_id] = degree
|
389 |
return result
|
390 |
|
391 |
+
async def edge_degrees_batch(
|
392 |
+
self, edge_pairs: list[tuple[str, str]]
|
393 |
+
) -> dict[tuple[str, str], int]:
|
394 |
"""Edge degrees as a batch using UNWIND also uses node_degrees_batch
|
395 |
+
|
396 |
Default implementation calculates edge degrees one by one.
|
397 |
Override this method for better performance in storage backends
|
398 |
that support batch operations.
|
|
|
403 |
result[(src_id, tgt_id)] = degree
|
404 |
return result
|
405 |
|
406 |
+
async def get_edges_batch(
|
407 |
+
self, pairs: list[dict[str, str]]
|
408 |
+
) -> dict[tuple[str, str], dict]:
|
409 |
"""Get edges as a batch using UNWIND
|
410 |
+
|
411 |
Default implementation fetches edges one by one.
|
412 |
Override this method for better performance in storage backends
|
413 |
that support batch operations.
|
|
|
421 |
result[(src_id, tgt_id)] = edge
|
422 |
return result
|
423 |
|
424 |
+
async def get_nodes_edges_batch(
|
425 |
+
self, node_ids: list[str]
|
426 |
+
) -> dict[str, list[tuple[str, str]]]:
|
427 |
"""Get nodes edges as a batch using UNWIND
|
428 |
+
|
429 |
Default implementation fetches node edges one by one.
|
430 |
Override this method for better performance in storage backends
|
431 |
that support batch operations.
|
lightrag/kg/neo4j_impl.py
CHANGED
@@ -311,10 +311,10 @@ class Neo4JStorage(BaseGraphStorage):
|
|
311 |
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
312 |
"""
|
313 |
Retrieve multiple nodes in one query using UNWIND.
|
314 |
-
|
315 |
Args:
|
316 |
node_ids: List of node entity IDs to fetch.
|
317 |
-
|
318 |
Returns:
|
319 |
A dictionary mapping each node_id to its node data (or None if not found).
|
320 |
"""
|
@@ -334,7 +334,9 @@ class Neo4JStorage(BaseGraphStorage):
|
|
334 |
node_dict = dict(node)
|
335 |
# Remove the 'base' label if present in a 'labels' property
|
336 |
if "labels" in node_dict:
|
337 |
-
node_dict["labels"] = [
|
|
|
|
|
338 |
nodes[entity_id] = node_dict
|
339 |
await result.consume() # Make sure to consume the result fully
|
340 |
return nodes
|
@@ -385,12 +387,12 @@ class Neo4JStorage(BaseGraphStorage):
|
|
385 |
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
386 |
"""
|
387 |
Retrieve the degree for multiple nodes in a single query using UNWIND.
|
388 |
-
|
389 |
Args:
|
390 |
node_ids: List of node labels (entity_id values) to look up.
|
391 |
-
|
392 |
Returns:
|
393 |
-
A dictionary mapping each node_id to its degree (number of relationships).
|
394 |
If a node is not found, its degree will be set to 0.
|
395 |
"""
|
396 |
async with self._driver.session(
|
@@ -407,13 +409,13 @@ class Neo4JStorage(BaseGraphStorage):
|
|
407 |
entity_id = record["entity_id"]
|
408 |
degrees[entity_id] = record["degree"]
|
409 |
await result.consume() # Ensure result is fully consumed
|
410 |
-
|
411 |
# For any node_id that did not return a record, set degree to 0.
|
412 |
for nid in node_ids:
|
413 |
if nid not in degrees:
|
414 |
logger.warning(f"No node found with label '{nid}'")
|
415 |
degrees[nid] = 0
|
416 |
-
|
417 |
logger.debug(f"Neo4j batch node degree query returned: {degrees}")
|
418 |
return degrees
|
419 |
|
@@ -436,25 +438,27 @@ class Neo4JStorage(BaseGraphStorage):
|
|
436 |
|
437 |
degrees = int(src_degree) + int(trg_degree)
|
438 |
return degrees
|
439 |
-
|
440 |
-
async def edge_degrees_batch(
|
|
|
|
|
441 |
"""
|
442 |
Calculate the combined degree for each edge (sum of the source and target node degrees)
|
443 |
in batch using the already implemented node_degrees_batch.
|
444 |
-
|
445 |
Args:
|
446 |
edge_pairs: List of (src, tgt) tuples.
|
447 |
-
|
448 |
Returns:
|
449 |
A dictionary mapping each (src, tgt) tuple to the sum of their degrees.
|
450 |
"""
|
451 |
# Collect unique node IDs from all edge pairs.
|
452 |
unique_node_ids = {src for src, _ in edge_pairs}
|
453 |
unique_node_ids.update({tgt for _, tgt in edge_pairs})
|
454 |
-
|
455 |
# Get degrees for all nodes in one go.
|
456 |
degrees = await self.node_degrees_batch(list(unique_node_ids))
|
457 |
-
|
458 |
# Sum up degrees for each edge pair.
|
459 |
edge_degrees = {}
|
460 |
for src, tgt in edge_pairs:
|
@@ -547,13 +551,15 @@ class Neo4JStorage(BaseGraphStorage):
|
|
547 |
)
|
548 |
raise
|
549 |
|
550 |
-
async def get_edges_batch(
|
|
|
|
|
551 |
"""
|
552 |
Retrieve edge properties for multiple (src, tgt) pairs in one query.
|
553 |
-
|
554 |
Args:
|
555 |
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
|
556 |
-
|
557 |
Returns:
|
558 |
A dictionary mapping (src, tgt) tuples to their edge properties.
|
559 |
"""
|
@@ -574,13 +580,23 @@ class Neo4JStorage(BaseGraphStorage):
|
|
574 |
if edges and len(edges) > 0:
|
575 |
edge_props = edges[0] # choose the first if multiple exist
|
576 |
# Ensure required keys exist with defaults
|
577 |
-
for key, default in {
|
|
|
|
|
|
|
|
|
|
|
578 |
if key not in edge_props:
|
579 |
edge_props[key] = default
|
580 |
edges_dict[(src, tgt)] = edge_props
|
581 |
else:
|
582 |
# No edge found – set default edge properties
|
583 |
-
edges_dict[(src, tgt)] = {
|
|
|
|
|
|
|
|
|
|
|
584 |
await result.consume()
|
585 |
return edges_dict
|
586 |
|
@@ -644,17 +660,21 @@ class Neo4JStorage(BaseGraphStorage):
|
|
644 |
logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
|
645 |
raise
|
646 |
|
647 |
-
async def get_nodes_edges_batch(
|
|
|
|
|
648 |
"""
|
649 |
Batch retrieve edges for multiple nodes in one query using UNWIND.
|
650 |
-
|
651 |
Args:
|
652 |
node_ids: List of node IDs (entity_id) for which to retrieve edges.
|
653 |
-
|
654 |
Returns:
|
655 |
A dictionary mapping each node ID to its list of edge tuples (source, target).
|
656 |
"""
|
657 |
-
async with self._driver.session(
|
|
|
|
|
658 |
query = """
|
659 |
UNWIND $node_ids AS id
|
660 |
MATCH (n:base {entity_id: id})
|
|
|
311 |
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
312 |
"""
|
313 |
Retrieve multiple nodes in one query using UNWIND.
|
314 |
+
|
315 |
Args:
|
316 |
node_ids: List of node entity IDs to fetch.
|
317 |
+
|
318 |
Returns:
|
319 |
A dictionary mapping each node_id to its node data (or None if not found).
|
320 |
"""
|
|
|
334 |
node_dict = dict(node)
|
335 |
# Remove the 'base' label if present in a 'labels' property
|
336 |
if "labels" in node_dict:
|
337 |
+
node_dict["labels"] = [
|
338 |
+
label for label in node_dict["labels"] if label != "base"
|
339 |
+
]
|
340 |
nodes[entity_id] = node_dict
|
341 |
await result.consume() # Make sure to consume the result fully
|
342 |
return nodes
|
|
|
387 |
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
388 |
"""
|
389 |
Retrieve the degree for multiple nodes in a single query using UNWIND.
|
390 |
+
|
391 |
Args:
|
392 |
node_ids: List of node labels (entity_id values) to look up.
|
393 |
+
|
394 |
Returns:
|
395 |
+
A dictionary mapping each node_id to its degree (number of relationships).
|
396 |
If a node is not found, its degree will be set to 0.
|
397 |
"""
|
398 |
async with self._driver.session(
|
|
|
409 |
entity_id = record["entity_id"]
|
410 |
degrees[entity_id] = record["degree"]
|
411 |
await result.consume() # Ensure result is fully consumed
|
412 |
+
|
413 |
# For any node_id that did not return a record, set degree to 0.
|
414 |
for nid in node_ids:
|
415 |
if nid not in degrees:
|
416 |
logger.warning(f"No node found with label '{nid}'")
|
417 |
degrees[nid] = 0
|
418 |
+
|
419 |
logger.debug(f"Neo4j batch node degree query returned: {degrees}")
|
420 |
return degrees
|
421 |
|
|
|
438 |
|
439 |
degrees = int(src_degree) + int(trg_degree)
|
440 |
return degrees
|
441 |
+
|
442 |
+
async def edge_degrees_batch(
|
443 |
+
self, edge_pairs: list[tuple[str, str]]
|
444 |
+
) -> dict[tuple[str, str], int]:
|
445 |
"""
|
446 |
Calculate the combined degree for each edge (sum of the source and target node degrees)
|
447 |
in batch using the already implemented node_degrees_batch.
|
448 |
+
|
449 |
Args:
|
450 |
edge_pairs: List of (src, tgt) tuples.
|
451 |
+
|
452 |
Returns:
|
453 |
A dictionary mapping each (src, tgt) tuple to the sum of their degrees.
|
454 |
"""
|
455 |
# Collect unique node IDs from all edge pairs.
|
456 |
unique_node_ids = {src for src, _ in edge_pairs}
|
457 |
unique_node_ids.update({tgt for _, tgt in edge_pairs})
|
458 |
+
|
459 |
# Get degrees for all nodes in one go.
|
460 |
degrees = await self.node_degrees_batch(list(unique_node_ids))
|
461 |
+
|
462 |
# Sum up degrees for each edge pair.
|
463 |
edge_degrees = {}
|
464 |
for src, tgt in edge_pairs:
|
|
|
551 |
)
|
552 |
raise
|
553 |
|
554 |
+
async def get_edges_batch(
|
555 |
+
self, pairs: list[dict[str, str]]
|
556 |
+
) -> dict[tuple[str, str], dict]:
|
557 |
"""
|
558 |
Retrieve edge properties for multiple (src, tgt) pairs in one query.
|
559 |
+
|
560 |
Args:
|
561 |
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
|
562 |
+
|
563 |
Returns:
|
564 |
A dictionary mapping (src, tgt) tuples to their edge properties.
|
565 |
"""
|
|
|
580 |
if edges and len(edges) > 0:
|
581 |
edge_props = edges[0] # choose the first if multiple exist
|
582 |
# Ensure required keys exist with defaults
|
583 |
+
for key, default in {
|
584 |
+
"weight": 0.0,
|
585 |
+
"source_id": None,
|
586 |
+
"description": None,
|
587 |
+
"keywords": None,
|
588 |
+
}.items():
|
589 |
if key not in edge_props:
|
590 |
edge_props[key] = default
|
591 |
edges_dict[(src, tgt)] = edge_props
|
592 |
else:
|
593 |
# No edge found – set default edge properties
|
594 |
+
edges_dict[(src, tgt)] = {
|
595 |
+
"weight": 0.0,
|
596 |
+
"source_id": None,
|
597 |
+
"description": None,
|
598 |
+
"keywords": None,
|
599 |
+
}
|
600 |
await result.consume()
|
601 |
return edges_dict
|
602 |
|
|
|
660 |
logger.error(f"Error in get_node_edges for {source_node_id}: {str(e)}")
|
661 |
raise
|
662 |
|
663 |
+
async def get_nodes_edges_batch(
|
664 |
+
self, node_ids: list[str]
|
665 |
+
) -> dict[str, list[tuple[str, str]]]:
|
666 |
"""
|
667 |
Batch retrieve edges for multiple nodes in one query using UNWIND.
|
668 |
+
|
669 |
Args:
|
670 |
node_ids: List of node IDs (entity_id) for which to retrieve edges.
|
671 |
+
|
672 |
Returns:
|
673 |
A dictionary mapping each node ID to its list of edge tuples (source, target).
|
674 |
"""
|
675 |
+
async with self._driver.session(
|
676 |
+
database=self._DATABASE, default_access_mode="READ"
|
677 |
+
) as session:
|
678 |
query = """
|
679 |
UNWIND $node_ids AS id
|
680 |
MATCH (n:base {entity_id: id})
|
lightrag/kg/postgres_impl.py
CHANGED
@@ -1461,30 +1461,29 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1461 |
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
1462 |
"""
|
1463 |
Retrieve multiple nodes in one query using UNWIND.
|
1464 |
-
|
1465 |
Args:
|
1466 |
node_ids: List of node entity IDs to fetch.
|
1467 |
-
|
1468 |
Returns:
|
1469 |
A dictionary mapping each node_id to its node data (or None if not found).
|
1470 |
"""
|
1471 |
if not node_ids:
|
1472 |
return {}
|
1473 |
-
|
1474 |
# Format node IDs for the query
|
1475 |
-
formatted_ids = ", ".join(
|
1476 |
-
|
|
|
|
|
1477 |
query = """SELECT * FROM cypher('%s', $$
|
1478 |
UNWIND [%s] AS node_id
|
1479 |
MATCH (n:base {entity_id: node_id})
|
1480 |
RETURN node_id, n
|
1481 |
-
$$) AS (node_id text, n agtype)""" % (
|
1482 |
-
|
1483 |
-
formatted_ids
|
1484 |
-
)
|
1485 |
-
|
1486 |
results = await self._query(query)
|
1487 |
-
|
1488 |
# Build result dictionary
|
1489 |
nodes_dict = {}
|
1490 |
for result in results:
|
@@ -1492,28 +1491,32 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1492 |
node_dict = result["n"]["properties"]
|
1493 |
# Remove the 'base' label if present in a 'labels' property
|
1494 |
if "labels" in node_dict:
|
1495 |
-
node_dict["labels"] = [
|
|
|
|
|
1496 |
nodes_dict[result["node_id"]] = node_dict
|
1497 |
-
|
1498 |
return nodes_dict
|
1499 |
|
1500 |
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
1501 |
"""
|
1502 |
Retrieve the degree for multiple nodes in a single query using UNWIND.
|
1503 |
-
|
1504 |
Args:
|
1505 |
node_ids: List of node labels (entity_id values) to look up.
|
1506 |
-
|
1507 |
Returns:
|
1508 |
-
A dictionary mapping each node_id to its degree (number of relationships).
|
1509 |
If a node is not found, its degree will be set to 0.
|
1510 |
"""
|
1511 |
if not node_ids:
|
1512 |
return {}
|
1513 |
-
|
1514 |
# Format node IDs for the query
|
1515 |
-
formatted_ids = ", ".join(
|
1516 |
-
|
|
|
|
|
1517 |
query = """SELECT * FROM cypher('%s', $$
|
1518 |
UNWIND [%s] AS node_id
|
1519 |
MATCH (n:base {entity_id: node_id})
|
@@ -1521,112 +1524,122 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1521 |
RETURN node_id, count(r) AS degree
|
1522 |
$$) AS (node_id text, degree bigint)""" % (
|
1523 |
self.graph_name,
|
1524 |
-
formatted_ids
|
1525 |
)
|
1526 |
-
|
1527 |
results = await self._query(query)
|
1528 |
-
|
1529 |
# Build result dictionary
|
1530 |
degrees_dict = {}
|
1531 |
for result in results:
|
1532 |
if result["node_id"] is not None:
|
1533 |
degrees_dict[result["node_id"]] = int(result["degree"])
|
1534 |
-
|
1535 |
# Ensure all requested node_ids are in the result dictionary
|
1536 |
for node_id in node_ids:
|
1537 |
if node_id not in degrees_dict:
|
1538 |
degrees_dict[node_id] = 0
|
1539 |
-
|
1540 |
return degrees_dict
|
1541 |
-
|
1542 |
-
async def edge_degrees_batch(
|
|
|
|
|
1543 |
"""
|
1544 |
Calculate the combined degree for each edge (sum of the source and target node degrees)
|
1545 |
in batch using the already implemented node_degrees_batch.
|
1546 |
-
|
1547 |
Args:
|
1548 |
edges: List of (source_node_id, target_node_id) tuples
|
1549 |
-
|
1550 |
Returns:
|
1551 |
Dictionary mapping edge tuples to their combined degrees
|
1552 |
"""
|
1553 |
if not edges:
|
1554 |
return {}
|
1555 |
-
|
1556 |
# Use node_degrees_batch to get all node degrees efficiently
|
1557 |
all_nodes = set()
|
1558 |
for src, tgt in edges:
|
1559 |
all_nodes.add(src)
|
1560 |
all_nodes.add(tgt)
|
1561 |
-
|
1562 |
node_degrees = await self.node_degrees_batch(list(all_nodes))
|
1563 |
-
|
1564 |
# Calculate edge degrees
|
1565 |
edge_degrees_dict = {}
|
1566 |
for src, tgt in edges:
|
1567 |
src_degree = node_degrees.get(src, 0)
|
1568 |
tgt_degree = node_degrees.get(tgt, 0)
|
1569 |
edge_degrees_dict[(src, tgt)] = src_degree + tgt_degree
|
1570 |
-
|
1571 |
return edge_degrees_dict
|
1572 |
-
|
1573 |
-
async def get_edges_batch(
|
|
|
|
|
1574 |
"""
|
1575 |
Retrieve edge properties for multiple (src, tgt) pairs in one query.
|
1576 |
-
|
1577 |
Args:
|
1578 |
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
|
1579 |
-
|
1580 |
Returns:
|
1581 |
A dictionary mapping (src, tgt) tuples to their edge properties.
|
1582 |
"""
|
1583 |
if not pairs:
|
1584 |
return {}
|
1585 |
-
|
1586 |
# 从字典列表中提取源节点和目标节点ID
|
1587 |
src_nodes = []
|
1588 |
tgt_nodes = []
|
1589 |
for pair in pairs:
|
1590 |
-
src_nodes.append(pair["src"].replace('"',
|
1591 |
-
tgt_nodes.append(pair["tgt"].replace('"',
|
1592 |
-
|
1593 |
# 构建查询,使用数组索引来匹配源节点和目标节点
|
1594 |
src_array = ", ".join([f'"{src}"' for src in src_nodes])
|
1595 |
tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])
|
1596 |
-
|
1597 |
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
1598 |
WITH [{src_array}] AS sources, [{tgt_array}] AS targets
|
1599 |
UNWIND range(0, size(sources)-1) AS i
|
1600 |
MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]-(b:base {{entity_id: targets[i]}})
|
1601 |
RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
|
1602 |
$$) AS (source text, target text, edge_properties agtype)"""
|
1603 |
-
|
1604 |
results = await self._query(query)
|
1605 |
-
|
1606 |
# 构建结果字典
|
1607 |
edges_dict = {}
|
1608 |
for result in results:
|
1609 |
if result["source"] and result["target"] and result["edge_properties"]:
|
1610 |
-
edges_dict[(result["source"], result["target"])] = result[
|
1611 |
-
|
|
|
|
|
1612 |
return edges_dict
|
1613 |
-
|
1614 |
-
async def get_nodes_edges_batch(
|
|
|
|
|
1615 |
"""
|
1616 |
Get all edges for multiple nodes in a single batch operation.
|
1617 |
-
|
1618 |
Args:
|
1619 |
node_ids: List of node IDs to get edges for
|
1620 |
-
|
1621 |
Returns:
|
1622 |
Dictionary mapping node IDs to lists of (source, target) edge tuples
|
1623 |
"""
|
1624 |
if not node_ids:
|
1625 |
return {}
|
1626 |
-
|
1627 |
# Format node IDs for the query
|
1628 |
-
formatted_ids = ", ".join(
|
1629 |
-
|
|
|
|
|
1630 |
query = """SELECT * FROM cypher('%s', $$
|
1631 |
UNWIND [%s] AS node_id
|
1632 |
MATCH (n:base {entity_id: node_id})
|
@@ -1634,11 +1647,11 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1634 |
RETURN node_id, connected.entity_id AS connected_id
|
1635 |
$$) AS (node_id text, connected_id text)""" % (
|
1636 |
self.graph_name,
|
1637 |
-
formatted_ids
|
1638 |
)
|
1639 |
-
|
1640 |
results = await self._query(query)
|
1641 |
-
|
1642 |
# Build result dictionary
|
1643 |
nodes_edges_dict = {node_id: [] for node_id in node_ids}
|
1644 |
for result in results:
|
@@ -1646,9 +1659,9 @@ class PGGraphStorage(BaseGraphStorage):
|
|
1646 |
nodes_edges_dict[result["node_id"]].append(
|
1647 |
(result["node_id"], result["connected_id"])
|
1648 |
)
|
1649 |
-
|
1650 |
return nodes_edges_dict
|
1651 |
-
|
1652 |
async def get_all_labels(self) -> list[str]:
|
1653 |
"""
|
1654 |
Get all labels (node IDs) in the graph.
|
|
|
1461 |
async def get_nodes_batch(self, node_ids: list[str]) -> dict[str, dict]:
|
1462 |
"""
|
1463 |
Retrieve multiple nodes in one query using UNWIND.
|
1464 |
+
|
1465 |
Args:
|
1466 |
node_ids: List of node entity IDs to fetch.
|
1467 |
+
|
1468 |
Returns:
|
1469 |
A dictionary mapping each node_id to its node data (or None if not found).
|
1470 |
"""
|
1471 |
if not node_ids:
|
1472 |
return {}
|
1473 |
+
|
1474 |
# Format node IDs for the query
|
1475 |
+
formatted_ids = ", ".join(
|
1476 |
+
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
|
1477 |
+
)
|
1478 |
+
|
1479 |
query = """SELECT * FROM cypher('%s', $$
|
1480 |
UNWIND [%s] AS node_id
|
1481 |
MATCH (n:base {entity_id: node_id})
|
1482 |
RETURN node_id, n
|
1483 |
+
$$) AS (node_id text, n agtype)""" % (self.graph_name, formatted_ids)
|
1484 |
+
|
|
|
|
|
|
|
1485 |
results = await self._query(query)
|
1486 |
+
|
1487 |
# Build result dictionary
|
1488 |
nodes_dict = {}
|
1489 |
for result in results:
|
|
|
1491 |
node_dict = result["n"]["properties"]
|
1492 |
# Remove the 'base' label if present in a 'labels' property
|
1493 |
if "labels" in node_dict:
|
1494 |
+
node_dict["labels"] = [
|
1495 |
+
label for label in node_dict["labels"] if label != "base"
|
1496 |
+
]
|
1497 |
nodes_dict[result["node_id"]] = node_dict
|
1498 |
+
|
1499 |
return nodes_dict
|
1500 |
|
1501 |
async def node_degrees_batch(self, node_ids: list[str]) -> dict[str, int]:
|
1502 |
"""
|
1503 |
Retrieve the degree for multiple nodes in a single query using UNWIND.
|
1504 |
+
|
1505 |
Args:
|
1506 |
node_ids: List of node labels (entity_id values) to look up.
|
1507 |
+
|
1508 |
Returns:
|
1509 |
+
A dictionary mapping each node_id to its degree (number of relationships).
|
1510 |
If a node is not found, its degree will be set to 0.
|
1511 |
"""
|
1512 |
if not node_ids:
|
1513 |
return {}
|
1514 |
+
|
1515 |
# Format node IDs for the query
|
1516 |
+
formatted_ids = ", ".join(
|
1517 |
+
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
|
1518 |
+
)
|
1519 |
+
|
1520 |
query = """SELECT * FROM cypher('%s', $$
|
1521 |
UNWIND [%s] AS node_id
|
1522 |
MATCH (n:base {entity_id: node_id})
|
|
|
1524 |
RETURN node_id, count(r) AS degree
|
1525 |
$$) AS (node_id text, degree bigint)""" % (
|
1526 |
self.graph_name,
|
1527 |
+
formatted_ids,
|
1528 |
)
|
1529 |
+
|
1530 |
results = await self._query(query)
|
1531 |
+
|
1532 |
# Build result dictionary
|
1533 |
degrees_dict = {}
|
1534 |
for result in results:
|
1535 |
if result["node_id"] is not None:
|
1536 |
degrees_dict[result["node_id"]] = int(result["degree"])
|
1537 |
+
|
1538 |
# Ensure all requested node_ids are in the result dictionary
|
1539 |
for node_id in node_ids:
|
1540 |
if node_id not in degrees_dict:
|
1541 |
degrees_dict[node_id] = 0
|
1542 |
+
|
1543 |
return degrees_dict
|
1544 |
+
|
1545 |
+
async def edge_degrees_batch(
|
1546 |
+
self, edges: list[tuple[str, str]]
|
1547 |
+
) -> dict[tuple[str, str], int]:
|
1548 |
"""
|
1549 |
Calculate the combined degree for each edge (sum of the source and target node degrees)
|
1550 |
in batch using the already implemented node_degrees_batch.
|
1551 |
+
|
1552 |
Args:
|
1553 |
edges: List of (source_node_id, target_node_id) tuples
|
1554 |
+
|
1555 |
Returns:
|
1556 |
Dictionary mapping edge tuples to their combined degrees
|
1557 |
"""
|
1558 |
if not edges:
|
1559 |
return {}
|
1560 |
+
|
1561 |
# Use node_degrees_batch to get all node degrees efficiently
|
1562 |
all_nodes = set()
|
1563 |
for src, tgt in edges:
|
1564 |
all_nodes.add(src)
|
1565 |
all_nodes.add(tgt)
|
1566 |
+
|
1567 |
node_degrees = await self.node_degrees_batch(list(all_nodes))
|
1568 |
+
|
1569 |
# Calculate edge degrees
|
1570 |
edge_degrees_dict = {}
|
1571 |
for src, tgt in edges:
|
1572 |
src_degree = node_degrees.get(src, 0)
|
1573 |
tgt_degree = node_degrees.get(tgt, 0)
|
1574 |
edge_degrees_dict[(src, tgt)] = src_degree + tgt_degree
|
1575 |
+
|
1576 |
return edge_degrees_dict
|
1577 |
+
|
1578 |
+
async def get_edges_batch(
|
1579 |
+
self, pairs: list[dict[str, str]]
|
1580 |
+
) -> dict[tuple[str, str], dict]:
|
1581 |
"""
|
1582 |
Retrieve edge properties for multiple (src, tgt) pairs in one query.
|
1583 |
+
|
1584 |
Args:
|
1585 |
pairs: List of dictionaries, e.g. [{"src": "node1", "tgt": "node2"}, ...]
|
1586 |
+
|
1587 |
Returns:
|
1588 |
A dictionary mapping (src, tgt) tuples to their edge properties.
|
1589 |
"""
|
1590 |
if not pairs:
|
1591 |
return {}
|
1592 |
+
|
1593 |
# 从字典列表中提取源节点和目标节点ID
|
1594 |
src_nodes = []
|
1595 |
tgt_nodes = []
|
1596 |
for pair in pairs:
|
1597 |
+
src_nodes.append(pair["src"].replace('"', ""))
|
1598 |
+
tgt_nodes.append(pair["tgt"].replace('"', ""))
|
1599 |
+
|
1600 |
# 构建查询,使用数组索引来匹配源节点和目标节点
|
1601 |
src_array = ", ".join([f'"{src}"' for src in src_nodes])
|
1602 |
tgt_array = ", ".join([f'"{tgt}"' for tgt in tgt_nodes])
|
1603 |
+
|
1604 |
query = f"""SELECT * FROM cypher('{self.graph_name}', $$
|
1605 |
WITH [{src_array}] AS sources, [{tgt_array}] AS targets
|
1606 |
UNWIND range(0, size(sources)-1) AS i
|
1607 |
MATCH (a:base {{entity_id: sources[i]}})-[r:DIRECTED]-(b:base {{entity_id: targets[i]}})
|
1608 |
RETURN sources[i] AS source, targets[i] AS target, properties(r) AS edge_properties
|
1609 |
$$) AS (source text, target text, edge_properties agtype)"""
|
1610 |
+
|
1611 |
results = await self._query(query)
|
1612 |
+
|
1613 |
# 构建结果字典
|
1614 |
edges_dict = {}
|
1615 |
for result in results:
|
1616 |
if result["source"] and result["target"] and result["edge_properties"]:
|
1617 |
+
edges_dict[(result["source"], result["target"])] = result[
|
1618 |
+
"edge_properties"
|
1619 |
+
]
|
1620 |
+
|
1621 |
return edges_dict
|
1622 |
+
|
1623 |
+
async def get_nodes_edges_batch(
|
1624 |
+
self, node_ids: list[str]
|
1625 |
+
) -> dict[str, list[tuple[str, str]]]:
|
1626 |
"""
|
1627 |
Get all edges for multiple nodes in a single batch operation.
|
1628 |
+
|
1629 |
Args:
|
1630 |
node_ids: List of node IDs to get edges for
|
1631 |
+
|
1632 |
Returns:
|
1633 |
Dictionary mapping node IDs to lists of (source, target) edge tuples
|
1634 |
"""
|
1635 |
if not node_ids:
|
1636 |
return {}
|
1637 |
+
|
1638 |
# Format node IDs for the query
|
1639 |
+
formatted_ids = ", ".join(
|
1640 |
+
['"' + node_id.replace('"', "") + '"' for node_id in node_ids]
|
1641 |
+
)
|
1642 |
+
|
1643 |
query = """SELECT * FROM cypher('%s', $$
|
1644 |
UNWIND [%s] AS node_id
|
1645 |
MATCH (n:base {entity_id: node_id})
|
|
|
1647 |
RETURN node_id, connected.entity_id AS connected_id
|
1648 |
$$) AS (node_id text, connected_id text)""" % (
|
1649 |
self.graph_name,
|
1650 |
+
formatted_ids,
|
1651 |
)
|
1652 |
+
|
1653 |
results = await self._query(query)
|
1654 |
+
|
1655 |
# Build result dictionary
|
1656 |
nodes_edges_dict = {node_id: [] for node_id in node_ids}
|
1657 |
for result in results:
|
|
|
1659 |
nodes_edges_dict[result["node_id"]].append(
|
1660 |
(result["node_id"], result["connected_id"])
|
1661 |
)
|
1662 |
+
|
1663 |
return nodes_edges_dict
|
1664 |
+
|
1665 |
async def get_all_labels(self) -> list[str]:
|
1666 |
"""
|
1667 |
Get all labels (node IDs) in the graph.
|
lightrag/operate.py
CHANGED
@@ -1323,14 +1323,14 @@ async def _get_node_data(
|
|
1323 |
|
1324 |
if not len(results):
|
1325 |
return "", "", ""
|
1326 |
-
|
1327 |
# Extract all entity IDs from your results list
|
1328 |
node_ids = [r["entity_name"] for r in results]
|
1329 |
|
1330 |
# Call the batch node retrieval and degree functions concurrently.
|
1331 |
nodes_dict, degrees_dict = await asyncio.gather(
|
1332 |
-
knowledge_graph_inst.get_nodes_batch(node_ids),
|
1333 |
-
knowledge_graph_inst.node_degrees_batch(node_ids)
|
1334 |
)
|
1335 |
|
1336 |
# Now, if you need the node data and degree in order:
|
@@ -1459,7 +1459,7 @@ async def _find_most_related_text_unit_from_entities(
|
|
1459 |
for dp in node_datas
|
1460 |
if dp["source_id"] is not None
|
1461 |
]
|
1462 |
-
|
1463 |
node_names = [dp["entity_name"] for dp in node_datas]
|
1464 |
batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names)
|
1465 |
# Build the edges list in the same order as node_datas.
|
@@ -1472,10 +1472,14 @@ async def _find_most_related_text_unit_from_entities(
|
|
1472 |
all_one_hop_nodes.update([e[1] for e in this_edges])
|
1473 |
|
1474 |
all_one_hop_nodes = list(all_one_hop_nodes)
|
1475 |
-
|
1476 |
# Batch retrieve one-hop node data using get_nodes_batch
|
1477 |
-
all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(
|
1478 |
-
|
|
|
|
|
|
|
|
|
1479 |
|
1480 |
# Add null check for node data
|
1481 |
all_one_hop_text_units_lookup = {
|
@@ -1571,13 +1575,13 @@ async def _find_most_related_edges_from_entities(
|
|
1571 |
edge_pairs_dicts = [{"src": e[0], "tgt": e[1]} for e in all_edges]
|
1572 |
# For edge degrees, use tuples.
|
1573 |
edge_pairs_tuples = list(all_edges) # all_edges is already a list of tuples
|
1574 |
-
|
1575 |
# Call the batched functions concurrently.
|
1576 |
edge_data_dict, edge_degrees_dict = await asyncio.gather(
|
1577 |
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
|
1578 |
-
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples)
|
1579 |
)
|
1580 |
-
|
1581 |
# Reconstruct edge_datas list in the same order as the deduplicated results.
|
1582 |
all_edges_data = []
|
1583 |
for pair in all_edges:
|
@@ -1590,7 +1594,6 @@ async def _find_most_related_edges_from_entities(
|
|
1590 |
}
|
1591 |
all_edges_data.append(combined)
|
1592 |
|
1593 |
-
|
1594 |
all_edges_data = sorted(
|
1595 |
all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
|
1596 |
)
|
@@ -1634,7 +1637,7 @@ async def _get_edge_data(
|
|
1634 |
# Call the batched functions concurrently.
|
1635 |
edge_data_dict, edge_degrees_dict = await asyncio.gather(
|
1636 |
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
|
1637 |
-
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples)
|
1638 |
)
|
1639 |
|
1640 |
# Reconstruct edge_datas list in the same order as results.
|
@@ -1652,7 +1655,7 @@ async def _get_edge_data(
|
|
1652 |
**edge_props,
|
1653 |
}
|
1654 |
edge_datas.append(combined)
|
1655 |
-
|
1656 |
edge_datas = sorted(
|
1657 |
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
|
1658 |
)
|
@@ -1761,7 +1764,7 @@ async def _find_most_related_entities_from_relationships(
|
|
1761 |
# Batch approach: Retrieve nodes and their degrees concurrently with one query each.
|
1762 |
nodes_dict, degrees_dict = await asyncio.gather(
|
1763 |
knowledge_graph_inst.get_nodes_batch(entity_names),
|
1764 |
-
knowledge_graph_inst.node_degrees_batch(entity_names)
|
1765 |
)
|
1766 |
|
1767 |
# Rebuild the list in the same order as entity_names
|
|
|
1323 |
|
1324 |
if not len(results):
|
1325 |
return "", "", ""
|
1326 |
+
|
1327 |
# Extract all entity IDs from your results list
|
1328 |
node_ids = [r["entity_name"] for r in results]
|
1329 |
|
1330 |
# Call the batch node retrieval and degree functions concurrently.
|
1331 |
nodes_dict, degrees_dict = await asyncio.gather(
|
1332 |
+
knowledge_graph_inst.get_nodes_batch(node_ids),
|
1333 |
+
knowledge_graph_inst.node_degrees_batch(node_ids),
|
1334 |
)
|
1335 |
|
1336 |
# Now, if you need the node data and degree in order:
|
|
|
1459 |
for dp in node_datas
|
1460 |
if dp["source_id"] is not None
|
1461 |
]
|
1462 |
+
|
1463 |
node_names = [dp["entity_name"] for dp in node_datas]
|
1464 |
batch_edges_dict = await knowledge_graph_inst.get_nodes_edges_batch(node_names)
|
1465 |
# Build the edges list in the same order as node_datas.
|
|
|
1472 |
all_one_hop_nodes.update([e[1] for e in this_edges])
|
1473 |
|
1474 |
all_one_hop_nodes = list(all_one_hop_nodes)
|
1475 |
+
|
1476 |
# Batch retrieve one-hop node data using get_nodes_batch
|
1477 |
+
all_one_hop_nodes_data_dict = await knowledge_graph_inst.get_nodes_batch(
|
1478 |
+
all_one_hop_nodes
|
1479 |
+
)
|
1480 |
+
all_one_hop_nodes_data = [
|
1481 |
+
all_one_hop_nodes_data_dict.get(e) for e in all_one_hop_nodes
|
1482 |
+
]
|
1483 |
|
1484 |
# Add null check for node data
|
1485 |
all_one_hop_text_units_lookup = {
|
|
|
1575 |
edge_pairs_dicts = [{"src": e[0], "tgt": e[1]} for e in all_edges]
|
1576 |
# For edge degrees, use tuples.
|
1577 |
edge_pairs_tuples = list(all_edges) # all_edges is already a list of tuples
|
1578 |
+
|
1579 |
# Call the batched functions concurrently.
|
1580 |
edge_data_dict, edge_degrees_dict = await asyncio.gather(
|
1581 |
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
|
1582 |
+
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples),
|
1583 |
)
|
1584 |
+
|
1585 |
# Reconstruct edge_datas list in the same order as the deduplicated results.
|
1586 |
all_edges_data = []
|
1587 |
for pair in all_edges:
|
|
|
1594 |
}
|
1595 |
all_edges_data.append(combined)
|
1596 |
|
|
|
1597 |
all_edges_data = sorted(
|
1598 |
all_edges_data, key=lambda x: (x["rank"], x["weight"]), reverse=True
|
1599 |
)
|
|
|
1637 |
# Call the batched functions concurrently.
|
1638 |
edge_data_dict, edge_degrees_dict = await asyncio.gather(
|
1639 |
knowledge_graph_inst.get_edges_batch(edge_pairs_dicts),
|
1640 |
+
knowledge_graph_inst.edge_degrees_batch(edge_pairs_tuples),
|
1641 |
)
|
1642 |
|
1643 |
# Reconstruct edge_datas list in the same order as results.
|
|
|
1655 |
**edge_props,
|
1656 |
}
|
1657 |
edge_datas.append(combined)
|
1658 |
+
|
1659 |
edge_datas = sorted(
|
1660 |
edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
|
1661 |
)
|
|
|
1764 |
# Batch approach: Retrieve nodes and their degrees concurrently with one query each.
|
1765 |
nodes_dict, degrees_dict = await asyncio.gather(
|
1766 |
knowledge_graph_inst.get_nodes_batch(entity_names),
|
1767 |
+
knowledge_graph_inst.node_degrees_batch(entity_names),
|
1768 |
)
|
1769 |
|
1770 |
# Rebuild the list in the same order as entity_names
|
lightrag_webui/src/stores/graph.ts
CHANGED
@@ -136,7 +136,7 @@ interface GraphState {
|
|
136 |
// Version counter to trigger data refresh
|
137 |
graphDataVersion: number
|
138 |
incrementGraphDataVersion: () => void
|
139 |
-
|
140 |
// Methods for updating graph elements and UI state together
|
141 |
updateNodeAndSelect: (nodeId: string, entityId: string, propertyName: string, newValue: string) => Promise<void>
|
142 |
updateEdgeAndSelect: (edgeId: string, dynamicId: string, sourceId: string, targetId: string, propertyName: string, newValue: string) => Promise<void>
|
@@ -252,40 +252,40 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
|
|
252 |
// Get current state
|
253 |
const state = get()
|
254 |
const { sigmaGraph, rawGraph } = state
|
255 |
-
|
256 |
// Validate graph state
|
257 |
if (!sigmaGraph || !rawGraph || !sigmaGraph.hasNode(nodeId)) {
|
258 |
return
|
259 |
}
|
260 |
-
|
261 |
try {
|
262 |
const nodeAttributes = sigmaGraph.getNodeAttributes(nodeId)
|
263 |
-
|
264 |
console.log('updateNodeAndSelect', nodeId, entityId, propertyName, newValue)
|
265 |
-
|
266 |
// For entity_id changes (node renaming) with NetworkX graph storage
|
267 |
if ((nodeId === entityId) && (propertyName === 'entity_id')) {
|
268 |
// Create new node with updated ID but same attributes
|
269 |
sigmaGraph.addNode(newValue, { ...nodeAttributes, label: newValue })
|
270 |
-
|
271 |
const edgesToUpdate: EdgeToUpdate[] = []
|
272 |
-
|
273 |
// Process all edges connected to this node
|
274 |
sigmaGraph.forEachEdge(nodeId, (edge, attributes, source, target) => {
|
275 |
const otherNode = source === nodeId ? target : source
|
276 |
const isOutgoing = source === nodeId
|
277 |
-
|
278 |
// Get original edge dynamic ID for later reference
|
279 |
const originalEdgeDynamicId = edge
|
280 |
const edgeIndexInRawGraph = rawGraph.edgeDynamicIdMap[originalEdgeDynamicId]
|
281 |
-
|
282 |
// Create new edge with updated node reference
|
283 |
const newEdgeId = sigmaGraph.addEdge(
|
284 |
isOutgoing ? newValue : otherNode,
|
285 |
isOutgoing ? otherNode : newValue,
|
286 |
attributes
|
287 |
)
|
288 |
-
|
289 |
// Track edges that need updating in the raw graph
|
290 |
if (edgeIndexInRawGraph !== undefined) {
|
291 |
edgesToUpdate.push({
|
@@ -294,14 +294,14 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
|
|
294 |
edgeIndex: edgeIndexInRawGraph
|
295 |
})
|
296 |
}
|
297 |
-
|
298 |
// Remove the old edge
|
299 |
sigmaGraph.dropEdge(edge)
|
300 |
})
|
301 |
-
|
302 |
// Remove the old node after all edges are processed
|
303 |
sigmaGraph.dropNode(nodeId)
|
304 |
-
|
305 |
// Update node reference in raw graph data
|
306 |
const nodeIndex = rawGraph.nodeIdMap[nodeId]
|
307 |
if (nodeIndex !== undefined) {
|
@@ -311,7 +311,7 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
|
|
311 |
delete rawGraph.nodeIdMap[nodeId]
|
312 |
rawGraph.nodeIdMap[newValue] = nodeIndex
|
313 |
}
|
314 |
-
|
315 |
// Update all edge references in raw graph data
|
316 |
edgesToUpdate.forEach(({ originalDynamicId, newEdgeId, edgeIndex }) => {
|
317 |
if (rawGraph.edges[edgeIndex]) {
|
@@ -322,14 +322,14 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
|
|
322 |
if (rawGraph.edges[edgeIndex].target === nodeId) {
|
323 |
rawGraph.edges[edgeIndex].target = newValue
|
324 |
}
|
325 |
-
|
326 |
// Update dynamic ID mappings
|
327 |
rawGraph.edges[edgeIndex].dynamicId = newEdgeId
|
328 |
delete rawGraph.edgeDynamicIdMap[originalDynamicId]
|
329 |
rawGraph.edgeDynamicIdMap[newEdgeId] = edgeIndex
|
330 |
}
|
331 |
})
|
332 |
-
|
333 |
// Update selected node in store
|
334 |
set({ selectedNode: newValue, moveToSelectedNode: true })
|
335 |
} else {
|
@@ -342,7 +342,7 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
|
|
342 |
sigmaGraph.setNodeAttribute(String(nodeId), 'label', newValue)
|
343 |
}
|
344 |
}
|
345 |
-
|
346 |
// Trigger a re-render by incrementing the version counter
|
347 |
set((state) => ({ graphDataVersion: state.graphDataVersion + 1 }))
|
348 |
}
|
@@ -351,17 +351,17 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
|
|
351 |
throw new Error('Failed to update node in graph')
|
352 |
}
|
353 |
},
|
354 |
-
|
355 |
updateEdgeAndSelect: async (edgeId: string, dynamicId: string, sourceId: string, targetId: string, propertyName: string, newValue: string) => {
|
356 |
// Get current state
|
357 |
const state = get()
|
358 |
const { sigmaGraph, rawGraph } = state
|
359 |
-
|
360 |
// Validate graph state
|
361 |
if (!sigmaGraph || !rawGraph) {
|
362 |
return
|
363 |
}
|
364 |
-
|
365 |
try {
|
366 |
const edgeIndex = rawGraph.edgeIdMap[String(edgeId)]
|
367 |
if (edgeIndex !== undefined && rawGraph.edges[edgeIndex]) {
|
@@ -370,10 +370,10 @@ const useGraphStoreBase = create<GraphState>()((set, get) => ({
|
|
370 |
sigmaGraph.setEdgeAttribute(dynamicId, 'label', newValue)
|
371 |
}
|
372 |
}
|
373 |
-
|
374 |
// Trigger a re-render by incrementing the version counter
|
375 |
set((state) => ({ graphDataVersion: state.graphDataVersion + 1 }))
|
376 |
-
|
377 |
// Update selected edge in store to ensure UI reflects changes
|
378 |
set({ selectedEdge: dynamicId })
|
379 |
} catch (error) {
|
|
|
136 |
// Version counter to trigger data refresh
|
137 |
graphDataVersion: number
|
138 |
incrementGraphDataVersion: () => void
|
139 |
+
|
140 |
// Methods for updating graph elements and UI state together
|
141 |
updateNodeAndSelect: (nodeId: string, entityId: string, propertyName: string, newValue: string) => Promise<void>
|
142 |
updateEdgeAndSelect: (edgeId: string, dynamicId: string, sourceId: string, targetId: string, propertyName: string, newValue: string) => Promise<void>
|
|
|
252 |
// Get current state
|
253 |
const state = get()
|
254 |
const { sigmaGraph, rawGraph } = state
|
255 |
+
|
256 |
// Validate graph state
|
257 |
if (!sigmaGraph || !rawGraph || !sigmaGraph.hasNode(nodeId)) {
|
258 |
return
|
259 |
}
|
260 |
+
|
261 |
try {
|
262 |
const nodeAttributes = sigmaGraph.getNodeAttributes(nodeId)
|
263 |
+
|
264 |
console.log('updateNodeAndSelect', nodeId, entityId, propertyName, newValue)
|
265 |
+
|
266 |
// For entity_id changes (node renaming) with NetworkX graph storage
|
267 |
if ((nodeId === entityId) && (propertyName === 'entity_id')) {
|
268 |
// Create new node with updated ID but same attributes
|
269 |
sigmaGraph.addNode(newValue, { ...nodeAttributes, label: newValue })
|
270 |
+
|
271 |
const edgesToUpdate: EdgeToUpdate[] = []
|
272 |
+
|
273 |
// Process all edges connected to this node
|
274 |
sigmaGraph.forEachEdge(nodeId, (edge, attributes, source, target) => {
|
275 |
const otherNode = source === nodeId ? target : source
|
276 |
const isOutgoing = source === nodeId
|
277 |
+
|
278 |
// Get original edge dynamic ID for later reference
|
279 |
const originalEdgeDynamicId = edge
|
280 |
const edgeIndexInRawGraph = rawGraph.edgeDynamicIdMap[originalEdgeDynamicId]
|
281 |
+
|
282 |
// Create new edge with updated node reference
|
283 |
const newEdgeId = sigmaGraph.addEdge(
|
284 |
isOutgoing ? newValue : otherNode,
|
285 |
isOutgoing ? otherNode : newValue,
|
286 |
attributes
|
287 |
)
|
288 |
+
|
289 |
// Track edges that need updating in the raw graph
|
290 |
if (edgeIndexInRawGraph !== undefined) {
|
291 |
edgesToUpdate.push({
|
|
|
294 |
edgeIndex: edgeIndexInRawGraph
|
295 |
})
|
296 |
}
|
297 |
+
|
298 |
// Remove the old edge
|
299 |
sigmaGraph.dropEdge(edge)
|
300 |
})
|
301 |
+
|
302 |
// Remove the old node after all edges are processed
|
303 |
sigmaGraph.dropNode(nodeId)
|
304 |
+
|
305 |
// Update node reference in raw graph data
|
306 |
const nodeIndex = rawGraph.nodeIdMap[nodeId]
|
307 |
if (nodeIndex !== undefined) {
|
|
|
311 |
delete rawGraph.nodeIdMap[nodeId]
|
312 |
rawGraph.nodeIdMap[newValue] = nodeIndex
|
313 |
}
|
314 |
+
|
315 |
// Update all edge references in raw graph data
|
316 |
edgesToUpdate.forEach(({ originalDynamicId, newEdgeId, edgeIndex }) => {
|
317 |
if (rawGraph.edges[edgeIndex]) {
|
|
|
322 |
if (rawGraph.edges[edgeIndex].target === nodeId) {
|
323 |
rawGraph.edges[edgeIndex].target = newValue
|
324 |
}
|
325 |
+
|
326 |
// Update dynamic ID mappings
|
327 |
rawGraph.edges[edgeIndex].dynamicId = newEdgeId
|
328 |
delete rawGraph.edgeDynamicIdMap[originalDynamicId]
|
329 |
rawGraph.edgeDynamicIdMap[newEdgeId] = edgeIndex
|
330 |
}
|
331 |
})
|
332 |
+
|
333 |
// Update selected node in store
|
334 |
set({ selectedNode: newValue, moveToSelectedNode: true })
|
335 |
} else {
|
|
|
342 |
sigmaGraph.setNodeAttribute(String(nodeId), 'label', newValue)
|
343 |
}
|
344 |
}
|
345 |
+
|
346 |
// Trigger a re-render by incrementing the version counter
|
347 |
set((state) => ({ graphDataVersion: state.graphDataVersion + 1 }))
|
348 |
}
|
|
|
351 |
throw new Error('Failed to update node in graph')
|
352 |
}
|
353 |
},
|
354 |
+
|
355 |
updateEdgeAndSelect: async (edgeId: string, dynamicId: string, sourceId: string, targetId: string, propertyName: string, newValue: string) => {
|
356 |
// Get current state
|
357 |
const state = get()
|
358 |
const { sigmaGraph, rawGraph } = state
|
359 |
+
|
360 |
// Validate graph state
|
361 |
if (!sigmaGraph || !rawGraph) {
|
362 |
return
|
363 |
}
|
364 |
+
|
365 |
try {
|
366 |
const edgeIndex = rawGraph.edgeIdMap[String(edgeId)]
|
367 |
if (edgeIndex !== undefined && rawGraph.edges[edgeIndex]) {
|
|
|
370 |
sigmaGraph.setEdgeAttribute(dynamicId, 'label', newValue)
|
371 |
}
|
372 |
}
|
373 |
+
|
374 |
// Trigger a re-render by incrementing the version counter
|
375 |
set((state) => ({ graphDataVersion: state.graphDataVersion + 1 }))
|
376 |
+
|
377 |
// Update selected edge in store to ensure UI reflects changes
|
378 |
set({ selectedEdge: dynamicId })
|
379 |
} catch (error) {
|
lightrag_webui/src/utils/graphOperations.ts
CHANGED
@@ -3,7 +3,7 @@ import { useGraphStore } from '@/stores/graph'
|
|
3 |
/**
|
4 |
* Update node in the graph visualization
|
5 |
* This function is now a wrapper around the store's updateNodeAndSelect method
|
6 |
-
*
|
7 |
* @param nodeId - ID of the node to update
|
8 |
* @param entityId - ID of the entity
|
9 |
* @param propertyName - Name of the property being updated
|
|
|
3 |
/**
|
4 |
* Update node in the graph visualization
|
5 |
* This function is now a wrapper around the store's updateNodeAndSelect method
|
6 |
+
*
|
7 |
* @param nodeId - ID of the node to update
|
8 |
* @param entityId - ID of the entity
|
9 |
* @param propertyName - Name of the property being updated
|
tests/test_graph_storage.py
CHANGED
@@ -510,35 +510,66 @@ async def test_graph_batch_operations(storage):
|
|
510 |
assert node1_id in nodes_dict, f"{node1_id} 应在返回结果中"
|
511 |
assert node2_id in nodes_dict, f"{node2_id} 应在返回结果中"
|
512 |
assert node3_id in nodes_dict, f"{node3_id} 应在返回结果中"
|
513 |
-
assert
|
514 |
-
|
515 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
516 |
|
517 |
# 3. 测试 node_degrees_batch - 批量获取多个节点的度数
|
518 |
print("== 测试 node_degrees_batch")
|
519 |
node_degrees = await storage.node_degrees_batch(node_ids)
|
520 |
print(f"批量获取节点度数结果: {node_degrees}")
|
521 |
-
assert
|
|
|
|
|
522 |
assert node1_id in node_degrees, f"{node1_id} 应在返回结果中"
|
523 |
assert node2_id in node_degrees, f"{node2_id} 应在返回结果中"
|
524 |
assert node3_id in node_degrees, f"{node3_id} 应在返回结果中"
|
525 |
-
assert
|
526 |
-
|
527 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
528 |
|
529 |
# 4. 测试 edge_degrees_batch - 批量获取多个边的度数
|
530 |
print("== 测试 edge_degrees_batch")
|
531 |
edges = [(node1_id, node2_id), (node2_id, node3_id), (node3_id, node4_id)]
|
532 |
edge_degrees = await storage.edge_degrees_batch(edges)
|
533 |
print(f"批量获取边度数结果: {edge_degrees}")
|
534 |
-
assert
|
535 |
-
|
536 |
-
|
537 |
-
assert (
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
538 |
# 验证边的度数是否正确(源节点度数 + 目标节点度数)
|
539 |
-
assert
|
540 |
-
|
541 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
542 |
|
543 |
# 5. 测试 get_edges_batch - 批量获取多个边的属性
|
544 |
print("== 测试 get_edges_batch")
|
@@ -547,28 +578,54 @@ async def test_graph_batch_operations(storage):
|
|
547 |
edges_dict = await storage.get_edges_batch(edge_dicts)
|
548 |
print(f"批量获取边属性结果: {edges_dict.keys()}")
|
549 |
assert len(edges_dict) == 3, f"应返回3条边的属性,实际返回 {len(edges_dict)} 条"
|
550 |
-
assert (
|
551 |
-
|
552 |
-
|
553 |
-
|
554 |
-
assert
|
555 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
556 |
|
557 |
# 6. 测试 get_nodes_edges_batch - 批量获取多个节点的所有边
|
558 |
print("== 测试 get_nodes_edges_batch")
|
559 |
nodes_edges = await storage.get_nodes_edges_batch([node1_id, node3_id])
|
560 |
print(f"批量获取节点边结果: {nodes_edges.keys()}")
|
561 |
-
assert
|
|
|
|
|
562 |
assert node1_id in nodes_edges, f"{node1_id} 应在返回结果中"
|
563 |
assert node3_id in nodes_edges, f"{node3_id} 应在返回结果中"
|
564 |
-
assert
|
565 |
-
|
|
|
|
|
|
|
|
|
566 |
|
567 |
# 7. 清理数据
|
568 |
print("== 测试 drop")
|
569 |
result = await storage.drop()
|
570 |
print(f"清理结果: {result}")
|
571 |
-
assert
|
|
|
|
|
572 |
|
573 |
print("\n批量操作测试完成")
|
574 |
return True
|
@@ -630,7 +687,7 @@ async def main():
|
|
630 |
if basic_result:
|
631 |
ASCIIColors.cyan("\n=== 开始高级测试 ===")
|
632 |
advanced_result = await test_graph_advanced(storage)
|
633 |
-
|
634 |
if advanced_result:
|
635 |
ASCIIColors.cyan("\n=== 开始批量操作测试 ===")
|
636 |
await test_graph_batch_operations(storage)
|
|
|
510 |
assert node1_id in nodes_dict, f"{node1_id} 应在返回结果中"
|
511 |
assert node2_id in nodes_dict, f"{node2_id} 应在返回结果中"
|
512 |
assert node3_id in nodes_dict, f"{node3_id} 应在返回结果中"
|
513 |
+
assert (
|
514 |
+
nodes_dict[node1_id]["description"] == node1_data["description"]
|
515 |
+
), f"{node1_id} 描述不匹配"
|
516 |
+
assert (
|
517 |
+
nodes_dict[node2_id]["description"] == node2_data["description"]
|
518 |
+
), f"{node2_id} 描述不匹配"
|
519 |
+
assert (
|
520 |
+
nodes_dict[node3_id]["description"] == node3_data["description"]
|
521 |
+
), f"{node3_id} 描述不匹配"
|
522 |
|
523 |
# 3. 测试 node_degrees_batch - 批量获取多个节点的度数
|
524 |
print("== 测试 node_degrees_batch")
|
525 |
node_degrees = await storage.node_degrees_batch(node_ids)
|
526 |
print(f"批量获取节点度数结果: {node_degrees}")
|
527 |
+
assert (
|
528 |
+
len(node_degrees) == 3
|
529 |
+
), f"应返回3个节点的度数,实际返回 {len(node_degrees)} 个"
|
530 |
assert node1_id in node_degrees, f"{node1_id} 应在返回结果中"
|
531 |
assert node2_id in node_degrees, f"{node2_id} 应在返回结果中"
|
532 |
assert node3_id in node_degrees, f"{node3_id} 应在返回结果中"
|
533 |
+
assert (
|
534 |
+
node_degrees[node1_id] == 3
|
535 |
+
), f"{node1_id} 度数应为3,实际为 {node_degrees[node1_id]}"
|
536 |
+
assert (
|
537 |
+
node_degrees[node2_id] == 2
|
538 |
+
), f"{node2_id} 度数应为2,实际为 {node_degrees[node2_id]}"
|
539 |
+
assert (
|
540 |
+
node_degrees[node3_id] == 3
|
541 |
+
), f"{node3_id} 度数应为3,实际为 {node_degrees[node3_id]}"
|
542 |
|
543 |
# 4. 测试 edge_degrees_batch - 批量获取多个边的度数
|
544 |
print("== 测试 edge_degrees_batch")
|
545 |
edges = [(node1_id, node2_id), (node2_id, node3_id), (node3_id, node4_id)]
|
546 |
edge_degrees = await storage.edge_degrees_batch(edges)
|
547 |
print(f"批量获取边度数结果: {edge_degrees}")
|
548 |
+
assert (
|
549 |
+
len(edge_degrees) == 3
|
550 |
+
), f"应返回3条边的度数,实际返回 {len(edge_degrees)} 条"
|
551 |
+
assert (
|
552 |
+
node1_id,
|
553 |
+
node2_id,
|
554 |
+
) in edge_degrees, f"边 {node1_id} -> {node2_id} 应在返回结果中"
|
555 |
+
assert (
|
556 |
+
node2_id,
|
557 |
+
node3_id,
|
558 |
+
) in edge_degrees, f"边 {node2_id} -> {node3_id} 应在返回结果中"
|
559 |
+
assert (
|
560 |
+
node3_id,
|
561 |
+
node4_id,
|
562 |
+
) in edge_degrees, f"边 {node3_id} -> {node4_id} 应在返回结果中"
|
563 |
# 验证边的度数是否正确(源节点度数 + 目标节点度数)
|
564 |
+
assert (
|
565 |
+
edge_degrees[(node1_id, node2_id)] == 5
|
566 |
+
), f"边 {node1_id} -> {node2_id} 度数应为5,实际为 {edge_degrees[(node1_id, node2_id)]}"
|
567 |
+
assert (
|
568 |
+
edge_degrees[(node2_id, node3_id)] == 5
|
569 |
+
), f"边 {node2_id} -> {node3_id} 度数应为5,实际为 {edge_degrees[(node2_id, node3_id)]}"
|
570 |
+
assert (
|
571 |
+
edge_degrees[(node3_id, node4_id)] == 5
|
572 |
+
), f"边 {node3_id} -> {node4_id} 度数应为5,实际为 {edge_degrees[(node3_id, node4_id)]}"
|
573 |
|
574 |
# 5. 测试 get_edges_batch - 批量获取多个边的属性
|
575 |
print("== 测试 get_edges_batch")
|
|
|
578 |
edges_dict = await storage.get_edges_batch(edge_dicts)
|
579 |
print(f"批量获取边属性结果: {edges_dict.keys()}")
|
580 |
assert len(edges_dict) == 3, f"应返回3条边的属性,实际返回 {len(edges_dict)} 条"
|
581 |
+
assert (
|
582 |
+
node1_id,
|
583 |
+
node2_id,
|
584 |
+
) in edges_dict, f"边 {node1_id} -> {node2_id} 应在返回结果中"
|
585 |
+
assert (
|
586 |
+
node2_id,
|
587 |
+
node3_id,
|
588 |
+
) in edges_dict, f"边 {node2_id} -> {node3_id} 应在返回结果中"
|
589 |
+
assert (
|
590 |
+
node3_id,
|
591 |
+
node4_id,
|
592 |
+
) in edges_dict, f"边 {node3_id} -> {node4_id} 应在返回结果中"
|
593 |
+
assert (
|
594 |
+
edges_dict[(node1_id, node2_id)]["relationship"]
|
595 |
+
== edge1_data["relationship"]
|
596 |
+
), f"边 {node1_id} -> {node2_id} 关系不匹配"
|
597 |
+
assert (
|
598 |
+
edges_dict[(node2_id, node3_id)]["relationship"]
|
599 |
+
== edge2_data["relationship"]
|
600 |
+
), f"边 {node2_id} -> {node3_id} 关系不匹配"
|
601 |
+
assert (
|
602 |
+
edges_dict[(node3_id, node4_id)]["relationship"]
|
603 |
+
== edge5_data["relationship"]
|
604 |
+
), f"边 {node3_id} -> {node4_id} 关系不匹配"
|
605 |
|
606 |
# 6. 测试 get_nodes_edges_batch - 批量获取多个节点的所有边
|
607 |
print("== 测试 get_nodes_edges_batch")
|
608 |
nodes_edges = await storage.get_nodes_edges_batch([node1_id, node3_id])
|
609 |
print(f"批量获取节点边结果: {nodes_edges.keys()}")
|
610 |
+
assert (
|
611 |
+
len(nodes_edges) == 2
|
612 |
+
), f"应返回2个节点的边,实际返回 {len(nodes_edges)} 个"
|
613 |
assert node1_id in nodes_edges, f"{node1_id} 应在返回结果中"
|
614 |
assert node3_id in nodes_edges, f"{node3_id} 应在返回结果中"
|
615 |
+
assert (
|
616 |
+
len(nodes_edges[node1_id]) == 3
|
617 |
+
), f"{node1_id} 应有3条边,实际有 {len(nodes_edges[node1_id])} 条"
|
618 |
+
assert (
|
619 |
+
len(nodes_edges[node3_id]) == 3
|
620 |
+
), f"{node3_id} 应有3条边,实际有 {len(nodes_edges[node3_id])} 条"
|
621 |
|
622 |
# 7. 清理数据
|
623 |
print("== 测试 drop")
|
624 |
result = await storage.drop()
|
625 |
print(f"清理结果: {result}")
|
626 |
+
assert (
|
627 |
+
result["status"] == "success"
|
628 |
+
), f"清理应成功,实际状态为 {result['status']}"
|
629 |
|
630 |
print("\n批量操作测试完成")
|
631 |
return True
|
|
|
687 |
if basic_result:
|
688 |
ASCIIColors.cyan("\n=== 开始高级测试 ===")
|
689 |
advanced_result = await test_graph_advanced(storage)
|
690 |
+
|
691 |
if advanced_result:
|
692 |
ASCIIColors.cyan("\n=== 开始批量操作测试 ===")
|
693 |
await test_graph_batch_operations(storage)
|