YanSte commited on
Commit
4aa228f
·
unverified ·
2 Parent(s): eaa50fc 60dd542

Merge pull request #901 from HKUDS/revert-886-clean-2

Browse files
lightrag/api/lightrag_server.py CHANGED
@@ -1683,6 +1683,10 @@ def create_app(args):
1683
  raise HTTPException(status_code=500, detail=str(e))
1684
 
1685
  # query all graph
 
 
 
 
1686
  # Add Ollama API routes
1687
  ollama_api = OllamaAPI(rag, top_k=args.top_k)
1688
  app.include_router(ollama_api.router, prefix="/api")
 
1683
  raise HTTPException(status_code=500, detail=str(e))
1684
 
1685
  # query all graph
1686
+ @app.get("/graphs")
1687
+ async def get_knowledge_graph(label: str):
1688
+ return await rag.get_knowledge_graph(nodel_label=label, max_depth=100)
1689
+
1690
  # Add Ollama API routes
1691
  ollama_api = OllamaAPI(rag, top_k=args.top_k)
1692
  app.include_router(ollama_api.router, prefix="/api")
lightrag/base.py CHANGED
@@ -13,6 +13,7 @@ from typing import (
13
  )
14
  import numpy as np
15
  from .utils import EmbeddingFunc
 
16
 
17
  load_dotenv()
18
 
@@ -197,6 +198,12 @@ class BaseGraphStorage(StorageNameSpace, ABC):
197
  ) -> tuple[np.ndarray[Any, Any], list[str]]:
198
  """Get all labels in the graph."""
199
 
 
 
 
 
 
 
200
 
201
  class DocStatus(str, Enum):
202
  """Document processing status"""
 
13
  )
14
  import numpy as np
15
  from .utils import EmbeddingFunc
16
+ from .types import KnowledgeGraph
17
 
18
  load_dotenv()
19
 
 
198
  ) -> tuple[np.ndarray[Any, Any], list[str]]:
199
  """Get all labels in the graph."""
200
 
201
+ @abstractmethod
202
+ async def get_knowledge_graph(
203
+ self, node_label: str, max_depth: int = 5
204
+ ) -> KnowledgeGraph:
205
+ """Retrieve a subgraph of the knowledge graph starting from a given node."""
206
+
207
 
208
  class DocStatus(str, Enum):
209
  """Document processing status"""
lightrag/kg/age_impl.py CHANGED
@@ -8,6 +8,7 @@ from dataclasses import dataclass
8
  from typing import Any, Dict, List, NamedTuple, Optional, Union, final
9
  import numpy as np
10
  import pipmaster as pm
 
11
 
12
  from tenacity import (
13
  retry,
@@ -615,6 +616,11 @@ class AGEStorage(BaseGraphStorage):
615
  ) -> tuple[np.ndarray[Any, Any], list[str]]:
616
  raise NotImplementedError
617
 
 
 
 
 
 
618
  async def index_done_callback(self) -> None:
619
  # AGES handles persistence automatically
620
  pass
 
8
  from typing import Any, Dict, List, NamedTuple, Optional, Union, final
9
  import numpy as np
10
  import pipmaster as pm
11
+ from lightrag.types import KnowledgeGraph
12
 
13
  from tenacity import (
14
  retry,
 
616
  ) -> tuple[np.ndarray[Any, Any], list[str]]:
617
  raise NotImplementedError
618
 
619
+ async def get_knowledge_graph(
620
+ self, node_label: str, max_depth: int = 5
621
+ ) -> KnowledgeGraph:
622
+ raise NotImplementedError
623
+
624
  async def index_done_callback(self) -> None:
625
  # AGES handles persistence automatically
626
  pass
lightrag/kg/gremlin_impl.py CHANGED
@@ -16,6 +16,7 @@ from tenacity import (
16
  wait_exponential,
17
  )
18
 
 
19
  from lightrag.utils import logger
20
 
21
  from ..base import BaseGraphStorage
@@ -401,3 +402,8 @@ class GremlinStorage(BaseGraphStorage):
401
  self, algorithm: str
402
  ) -> tuple[np.ndarray[Any, Any], list[str]]:
403
  raise NotImplementedError
 
 
 
 
 
 
16
  wait_exponential,
17
  )
18
 
19
+ from lightrag.types import KnowledgeGraph
20
  from lightrag.utils import logger
21
 
22
  from ..base import BaseGraphStorage
 
402
  self, algorithm: str
403
  ) -> tuple[np.ndarray[Any, Any], list[str]]:
404
  raise NotImplementedError
405
+
406
+ async def get_knowledge_graph(
407
+ self, node_label: str, max_depth: int = 5
408
+ ) -> KnowledgeGraph:
409
+ raise NotImplementedError
lightrag/kg/mongo_impl.py CHANGED
@@ -16,6 +16,7 @@ from ..base import (
16
  )
17
  from ..namespace import NameSpace, is_namespace
18
  from ..utils import logger
 
19
  import pipmaster as pm
20
 
21
  if not pm.is_installed("pymongo"):
@@ -598,6 +599,179 @@ class MongoGraphStorage(BaseGraphStorage):
598
  # -------------------------------------------------------------------------
599
  # QUERY
600
  # -------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
601
 
602
  async def index_done_callback(self) -> None:
603
  # Mongo handles persistence automatically
 
16
  )
17
  from ..namespace import NameSpace, is_namespace
18
  from ..utils import logger
19
+ from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
20
  import pipmaster as pm
21
 
22
  if not pm.is_installed("pymongo"):
 
599
  # -------------------------------------------------------------------------
600
  # QUERY
601
  # -------------------------------------------------------------------------
602
+ #
603
+
604
+ async def get_knowledge_graph(
605
+ self, node_label: str, max_depth: int = 5
606
+ ) -> KnowledgeGraph:
607
+ """
608
+ Get complete connected subgraph for specified node (including the starting node itself)
609
+
610
+ Args:
611
+ node_label: Label of the nodes to start from
612
+ max_depth: Maximum depth of traversal (default: 5)
613
+
614
+ Returns:
615
+ KnowledgeGraph object containing nodes and edges of the subgraph
616
+ """
617
+ label = node_label
618
+ result = KnowledgeGraph()
619
+ seen_nodes = set()
620
+ seen_edges = set()
621
+
622
+ try:
623
+ if label == "*":
624
+ # Get all nodes and edges
625
+ async for node_doc in self.collection.find({}):
626
+ node_id = str(node_doc["_id"])
627
+ if node_id not in seen_nodes:
628
+ result.nodes.append(
629
+ KnowledgeGraphNode(
630
+ id=node_id,
631
+ labels=[node_doc.get("_id")],
632
+ properties={
633
+ k: v
634
+ for k, v in node_doc.items()
635
+ if k not in ["_id", "edges"]
636
+ },
637
+ )
638
+ )
639
+ seen_nodes.add(node_id)
640
+
641
+ # Process edges
642
+ for edge in node_doc.get("edges", []):
643
+ edge_id = f"{node_id}-{edge['target']}"
644
+ if edge_id not in seen_edges:
645
+ result.edges.append(
646
+ KnowledgeGraphEdge(
647
+ id=edge_id,
648
+ type=edge.get("relation", ""),
649
+ source=node_id,
650
+ target=edge["target"],
651
+ properties={
652
+ k: v
653
+ for k, v in edge.items()
654
+ if k not in ["target", "relation"]
655
+ },
656
+ )
657
+ )
658
+ seen_edges.add(edge_id)
659
+ else:
660
+ # Verify if starting node exists
661
+ start_nodes = self.collection.find({"_id": label})
662
+ start_nodes_exist = await start_nodes.to_list(length=1)
663
+ if not start_nodes_exist:
664
+ logger.warning(f"Starting node with label {label} does not exist!")
665
+ return result
666
+
667
+ # Use $graphLookup for traversal
668
+ pipeline = [
669
+ {
670
+ "$match": {"_id": label}
671
+ }, # Start with nodes having the specified label
672
+ {
673
+ "$graphLookup": {
674
+ "from": self._collection_name,
675
+ "startWith": "$edges.target",
676
+ "connectFromField": "edges.target",
677
+ "connectToField": "_id",
678
+ "maxDepth": max_depth,
679
+ "depthField": "depth",
680
+ "as": "connected_nodes",
681
+ }
682
+ },
683
+ ]
684
+
685
+ async for doc in self.collection.aggregate(pipeline):
686
+ # Add the start node
687
+ node_id = str(doc["_id"])
688
+ if node_id not in seen_nodes:
689
+ result.nodes.append(
690
+ KnowledgeGraphNode(
691
+ id=node_id,
692
+ labels=[
693
+ doc.get(
694
+ "_id",
695
+ )
696
+ ],
697
+ properties={
698
+ k: v
699
+ for k, v in doc.items()
700
+ if k
701
+ not in [
702
+ "_id",
703
+ "edges",
704
+ "connected_nodes",
705
+ "depth",
706
+ ]
707
+ },
708
+ )
709
+ )
710
+ seen_nodes.add(node_id)
711
+
712
+ # Add edges from start node
713
+ for edge in doc.get("edges", []):
714
+ edge_id = f"{node_id}-{edge['target']}"
715
+ if edge_id not in seen_edges:
716
+ result.edges.append(
717
+ KnowledgeGraphEdge(
718
+ id=edge_id,
719
+ type=edge.get("relation", ""),
720
+ source=node_id,
721
+ target=edge["target"],
722
+ properties={
723
+ k: v
724
+ for k, v in edge.items()
725
+ if k not in ["target", "relation"]
726
+ },
727
+ )
728
+ )
729
+ seen_edges.add(edge_id)
730
+
731
+ # Add connected nodes and their edges
732
+ for connected in doc.get("connected_nodes", []):
733
+ node_id = str(connected["_id"])
734
+ if node_id not in seen_nodes:
735
+ result.nodes.append(
736
+ KnowledgeGraphNode(
737
+ id=node_id,
738
+ labels=[connected.get("_id")],
739
+ properties={
740
+ k: v
741
+ for k, v in connected.items()
742
+ if k not in ["_id", "edges", "depth"]
743
+ },
744
+ )
745
+ )
746
+ seen_nodes.add(node_id)
747
+
748
+ # Add edges from connected nodes
749
+ for edge in connected.get("edges", []):
750
+ edge_id = f"{node_id}-{edge['target']}"
751
+ if edge_id not in seen_edges:
752
+ result.edges.append(
753
+ KnowledgeGraphEdge(
754
+ id=edge_id,
755
+ type=edge.get("relation", ""),
756
+ source=node_id,
757
+ target=edge["target"],
758
+ properties={
759
+ k: v
760
+ for k, v in edge.items()
761
+ if k not in ["target", "relation"]
762
+ },
763
+ )
764
+ )
765
+ seen_edges.add(edge_id)
766
+
767
+ logger.info(
768
+ f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
769
+ )
770
+
771
+ except PyMongoError as e:
772
+ logger.error(f"MongoDB query failed: {str(e)}")
773
+
774
+ return result
775
 
776
  async def index_done_callback(self) -> None:
777
  # Mongo handles persistence automatically
lightrag/kg/neo4j_impl.py CHANGED
@@ -17,6 +17,7 @@ from tenacity import (
17
 
18
  from ..utils import logger
19
  from ..base import BaseGraphStorage
 
20
  import pipmaster as pm
21
 
22
  if not pm.is_installed("neo4j"):
@@ -468,6 +469,99 @@ class Neo4JStorage(BaseGraphStorage):
468
  async def _node2vec_embed(self):
469
  print("Implemented but never called.")
470
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
471
  async def _robust_fallback(
472
  self, label: str, max_depth: int
473
  ) -> Dict[str, List[Dict]]:
 
17
 
18
  from ..utils import logger
19
  from ..base import BaseGraphStorage
20
+ from ..types import KnowledgeGraph, KnowledgeGraphNode, KnowledgeGraphEdge
21
  import pipmaster as pm
22
 
23
  if not pm.is_installed("neo4j"):
 
469
  async def _node2vec_embed(self):
470
  print("Implemented but never called.")
471
 
472
+ async def get_knowledge_graph(
473
+ self, node_label: str, max_depth: int = 5
474
+ ) -> KnowledgeGraph:
475
+ """
476
+ Get complete connected subgraph for specified node (including the starting node itself)
477
+
478
+ Key fixes:
479
+ 1. Include the starting node itself
480
+ 2. Handle multi-label nodes
481
+ 3. Clarify relationship directions
482
+ 4. Add depth control
483
+ """
484
+ label = node_label.strip('"')
485
+ result = KnowledgeGraph()
486
+ seen_nodes = set()
487
+ seen_edges = set()
488
+
489
+ async with self._driver.session(database=self._DATABASE) as session:
490
+ try:
491
+ main_query = ""
492
+ if label == "*":
493
+ main_query = """
494
+ MATCH (n)
495
+ WITH collect(DISTINCT n) AS nodes
496
+ MATCH ()-[r]-()
497
+ RETURN nodes, collect(DISTINCT r) AS relationships;
498
+ """
499
+ else:
500
+ # Critical debug step: first verify if starting node exists
501
+ validate_query = f"MATCH (n:`{label}`) RETURN n LIMIT 1"
502
+ validate_result = await session.run(validate_query)
503
+ if not await validate_result.single():
504
+ logger.warning(f"Starting node {label} does not exist!")
505
+ return result
506
+
507
+ # Optimized query (including direction handling and self-loops)
508
+ main_query = f"""
509
+ MATCH (start:`{label}`)
510
+ WITH start
511
+ CALL apoc.path.subgraphAll(start, {{
512
+ relationshipFilter: '>',
513
+ minLevel: 0,
514
+ maxLevel: {max_depth},
515
+ bfs: true
516
+ }})
517
+ YIELD nodes, relationships
518
+ RETURN nodes, relationships
519
+ """
520
+ result_set = await session.run(main_query)
521
+ record = await result_set.single()
522
+
523
+ if record:
524
+ # Handle nodes (compatible with multi-label cases)
525
+ for node in record["nodes"]:
526
+ # Use node ID + label combination as unique identifier
527
+ node_id = node.id
528
+ if node_id not in seen_nodes:
529
+ result.nodes.append(
530
+ KnowledgeGraphNode(
531
+ id=f"{node_id}",
532
+ labels=list(node.labels),
533
+ properties=dict(node),
534
+ )
535
+ )
536
+ seen_nodes.add(node_id)
537
+
538
+ # Handle relationships (including direction information)
539
+ for rel in record["relationships"]:
540
+ edge_id = rel.id
541
+ if edge_id not in seen_edges:
542
+ start = rel.start_node
543
+ end = rel.end_node
544
+ result.edges.append(
545
+ KnowledgeGraphEdge(
546
+ id=f"{edge_id}",
547
+ type=rel.type,
548
+ source=f"{start.id}",
549
+ target=f"{end.id}",
550
+ properties=dict(rel),
551
+ )
552
+ )
553
+ seen_edges.add(edge_id)
554
+
555
+ logger.info(
556
+ f"Subgraph query successful | Node count: {len(result.nodes)} | Edge count: {len(result.edges)}"
557
+ )
558
+
559
+ except neo4jExceptions.ClientError as e:
560
+ logger.error(f"APOC query failed: {str(e)}")
561
+ return await self._robust_fallback(label, max_depth)
562
+
563
+ return result
564
+
565
  async def _robust_fallback(
566
  self, label: str, max_depth: int
567
  ) -> Dict[str, List[Dict]]:
lightrag/kg/networkx_impl.py CHANGED
@@ -5,6 +5,7 @@ from typing import Any, final
5
  import numpy as np
6
 
7
 
 
8
  from lightrag.utils import (
9
  logger,
10
  )
@@ -166,3 +167,8 @@ class NetworkXStorage(BaseGraphStorage):
166
  for source, target in edges:
167
  if self._graph.has_edge(source, target):
168
  self._graph.remove_edge(source, target)
 
 
 
 
 
 
5
  import numpy as np
6
 
7
 
8
+ from lightrag.types import KnowledgeGraph
9
  from lightrag.utils import (
10
  logger,
11
  )
 
167
  for source, target in edges:
168
  if self._graph.has_edge(source, target):
169
  self._graph.remove_edge(source, target)
170
+
171
+ async def get_knowledge_graph(
172
+ self, node_label: str, max_depth: int = 5
173
+ ) -> KnowledgeGraph:
174
+ raise NotImplementedError
lightrag/kg/oracle_impl.py CHANGED
@@ -8,6 +8,7 @@ from typing import Any, Union, final
8
  import numpy as np
9
  import configparser
10
 
 
11
 
12
  from ..base import (
13
  BaseGraphStorage,
@@ -669,6 +670,11 @@ class OracleGraphStorage(BaseGraphStorage):
669
  async def delete_node(self, node_id: str) -> None:
670
  raise NotImplementedError
671
 
 
 
 
 
 
672
 
673
  N_T = {
674
  NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
 
8
  import numpy as np
9
  import configparser
10
 
11
+ from lightrag.types import KnowledgeGraph
12
 
13
  from ..base import (
14
  BaseGraphStorage,
 
670
  async def delete_node(self, node_id: str) -> None:
671
  raise NotImplementedError
672
 
673
+ async def get_knowledge_graph(
674
+ self, node_label: str, max_depth: int = 5
675
+ ) -> KnowledgeGraph:
676
+ raise NotImplementedError
677
+
678
 
679
  N_T = {
680
  NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
lightrag/kg/postgres_impl.py CHANGED
@@ -7,6 +7,7 @@ from typing import Any, Union, final
7
  import numpy as np
8
  import configparser
9
 
 
10
 
11
  import sys
12
  from tenacity import (
@@ -1084,6 +1085,11 @@ class PGGraphStorage(BaseGraphStorage):
1084
  ) -> tuple[np.ndarray[Any, Any], list[str]]:
1085
  raise NotImplementedError
1086
 
 
 
 
 
 
1087
  async def drop(self) -> None:
1088
  """Drop the storage"""
1089
  drop_sql = SQL_TEMPLATES["drop_vdb_entity"]
 
7
  import numpy as np
8
  import configparser
9
 
10
+ from lightrag.types import KnowledgeGraph
11
 
12
  import sys
13
  from tenacity import (
 
1085
  ) -> tuple[np.ndarray[Any, Any], list[str]]:
1086
  raise NotImplementedError
1087
 
1088
+ async def get_knowledge_graph(
1089
+ self, node_label: str, max_depth: int = 5
1090
+ ) -> KnowledgeGraph:
1091
+ raise NotImplementedError
1092
+
1093
  async def drop(self) -> None:
1094
  """Drop the storage"""
1095
  drop_sql = SQL_TEMPLATES["drop_vdb_entity"]
lightrag/kg/tidb_impl.py CHANGED
@@ -5,6 +5,8 @@ from typing import Any, Union, final
5
 
6
  import numpy as np
7
 
 
 
8
 
9
  from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
10
  from ..namespace import NameSpace, is_namespace
@@ -558,6 +560,11 @@ class TiDBGraphStorage(BaseGraphStorage):
558
  async def delete_node(self, node_id: str) -> None:
559
  raise NotImplementedError
560
 
 
 
 
 
 
561
 
562
  N_T = {
563
  NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
 
5
 
6
  import numpy as np
7
 
8
+ from lightrag.types import KnowledgeGraph
9
+
10
 
11
  from ..base import BaseGraphStorage, BaseKVStorage, BaseVectorStorage
12
  from ..namespace import NameSpace, is_namespace
 
560
  async def delete_node(self, node_id: str) -> None:
561
  raise NotImplementedError
562
 
563
+ async def get_knowledge_graph(
564
+ self, node_label: str, max_depth: int = 5
565
+ ) -> KnowledgeGraph:
566
+ raise NotImplementedError
567
+
568
 
569
  N_T = {
570
  NameSpace.KV_STORE_FULL_DOCS: "LIGHTRAG_DOC_FULL",
lightrag/lightrag.py CHANGED
@@ -47,6 +47,7 @@ from .utils import (
47
  set_logger,
48
  encode_string_by_tiktoken,
49
  )
 
50
 
51
  # TODO: TO REMOVE @Yannick
52
  config = configparser.ConfigParser()
@@ -457,6 +458,13 @@ class LightRAG:
457
  self._storages_status = StoragesStatus.FINALIZED
458
  logger.debug("Finalized Storages")
459
 
 
 
 
 
 
 
 
460
  def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
461
  import_path = STORAGES[storage_name]
462
  storage_class = lazy_external_import(import_path, storage_name)
 
47
  set_logger,
48
  encode_string_by_tiktoken,
49
  )
50
+ from .types import KnowledgeGraph
51
 
52
  # TODO: TO REMOVE @Yannick
53
  config = configparser.ConfigParser()
 
458
  self._storages_status = StoragesStatus.FINALIZED
459
  logger.debug("Finalized Storages")
460
 
461
+ async def get_knowledge_graph(
462
+ self, nodel_label: str, max_depth: int
463
+ ) -> KnowledgeGraph:
464
+ return await self.chunk_entity_relation_graph.get_knowledge_graph(
465
+ node_label=nodel_label, max_depth=max_depth
466
+ )
467
+
468
  def _get_storage_class(self, storage_name: str) -> Callable[..., Any]:
469
  import_path = STORAGES[storage_name]
470
  storage_class = lazy_external_import(import_path, storage_name)