samuel-z-chen commited on
Commit
90583bd
·
1 Parent(s): acb1edf

Revised the postgres implementation, to use attributes(node_id) rather than nodes to identify an entity. Which significantly reduced the table counts.

Browse files
Files changed (2) hide show
  1. README.md +5 -0
  2. lightrag/kg/postgres_impl.py +114 -156
README.md CHANGED
@@ -361,6 +361,11 @@ see test_neo4j.py for a working example.
361
  For production level scenarios you will most likely want to leverage an enterprise solution. PostgreSQL can provide a one-stop solution for you as KV store, VectorDB (pgvector) and GraphDB (apache AGE).
362
  * PostgreSQL is lightweight,the whole binary distribution including all necessary plugins can be zipped to 40MB: Ref to [Windows Release](https://github.com/ShanGor/apache-age-windows/releases/tag/PG17%2Fv1.5.0-rc0) as it is easy to install for Linux/Mac.
363
  * How to start? Ref to: [examples/lightrag_zhipu_postgres_demo.py](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_zhipu_postgres_demo.py)
 
 
 
 
 
364
 
365
  ### Insert Custom KG
366
 
 
361
  For production level scenarios you will most likely want to leverage an enterprise solution. PostgreSQL can provide a one-stop solution for you as KV store, VectorDB (pgvector) and GraphDB (apache AGE).
362
  * PostgreSQL is lightweight,the whole binary distribution including all necessary plugins can be zipped to 40MB: Ref to [Windows Release](https://github.com/ShanGor/apache-age-windows/releases/tag/PG17%2Fv1.5.0-rc0) as it is easy to install for Linux/Mac.
363
  * How to start? Ref to: [examples/lightrag_zhipu_postgres_demo.py](https://github.com/HKUDS/LightRAG/blob/main/examples/lightrag_zhipu_postgres_demo.py)
364
+ * Create index for AGE example: (Change below `dickens` to your graph name if necessary)
365
+ ```
366
+ SET search_path = ag_catalog, "$user", public;
367
+ CREATE INDEX idx_entity ON dickens."Entity" USING gin (agtype_access_operator(properties, '"node_id"'));
368
+ ```
369
 
370
  ### Insert Custom KG
371
 
lightrag/kg/postgres_impl.py CHANGED
@@ -130,6 +130,7 @@ class PostgreSQLDB:
130
  data: Union[list, dict] = None,
131
  for_age: bool = False,
132
  graph_name: str = None,
 
133
  ):
134
  try:
135
  async with self.pool.acquire() as connection:
@@ -140,6 +141,11 @@ class PostgreSQLDB:
140
  await connection.execute(sql)
141
  else:
142
  await connection.execute(sql, *data.values())
 
 
 
 
 
143
  except Exception as e:
144
  logger.error(f"PostgreSQL database error: {e}")
145
  print(sql)
@@ -568,10 +574,10 @@ class PGGraphStorage(BaseGraphStorage):
568
 
569
  if dtype == "vertex":
570
  vertex = json.loads(v)
571
- field = json.loads(v).get("properties")
572
  if not field:
573
  field = {}
574
- field["label"] = PGGraphStorage._decode_graph_label(vertex["label"])
575
  d[k] = field
576
  # convert edge from id-label->id by replacing id with node information
577
  # we only do this if the vertex was also returned in the query
@@ -666,73 +672,8 @@ class PGGraphStorage(BaseGraphStorage):
666
  # otherwise return the value stripping out some common special chars
667
  return field.replace("(", "_").replace(")", "")
668
 
669
- @staticmethod
670
- def _wrap_query(query: str, graph_name: str, **params: str) -> str:
671
- """
672
- Convert a cypher query to an Apache Age compatible
673
- sql query by wrapping the cypher query in ag_catalog.cypher,
674
- casting results to agtype and building a select statement
675
-
676
- Args:
677
- query (str): a valid cypher query
678
- graph_name (str): the name of the graph to query
679
- params (dict): parameters for the query
680
-
681
- Returns:
682
- str: an equivalent pgsql query
683
- """
684
-
685
- # pgsql template
686
- template = """SELECT {projection} FROM ag_catalog.cypher('{graph_name}', $$
687
- {query}
688
- $$) AS ({fields})"""
689
-
690
- # if there are any returned fields they must be added to the pgsql query
691
- if "return" in query.lower():
692
- # parse return statement to identify returned fields
693
- fields = (
694
- query.lower()
695
- .split("return")[-1]
696
- .split("distinct")[-1]
697
- .split("order by")[0]
698
- .split("skip")[0]
699
- .split("limit")[0]
700
- .split(",")
701
- )
702
-
703
- # raise exception if RETURN * is found as we can't resolve the fields
704
- if "*" in [x.strip() for x in fields]:
705
- raise ValueError(
706
- "AGE graph does not support 'RETURN *'"
707
- + " statements in Cypher queries"
708
- )
709
-
710
- # get pgsql formatted field names
711
- fields = [
712
- PGGraphStorage._get_col_name(field, idx)
713
- for idx, field in enumerate(fields)
714
- ]
715
-
716
- # build resulting pgsql relation
717
- fields_str = ", ".join(
718
- [field.split(".")[-1] + " agtype" for field in fields]
719
- )
720
-
721
- # if no return statement we still need to return a single field of type agtype
722
- else:
723
- fields_str = "a agtype"
724
-
725
- select_str = "*"
726
-
727
- return template.format(
728
- graph_name=graph_name,
729
- query=query.format(**params),
730
- fields=fields_str,
731
- projection=select_str,
732
- )
733
-
734
  async def _query(
735
- self, query: str, readonly=True, upsert_edge=False, **params: str
736
  ) -> List[Dict[str, Any]]:
737
  """
738
  Query the graph by taking a cypher query, converting it to an
@@ -746,7 +687,7 @@ class PGGraphStorage(BaseGraphStorage):
746
  List[Dict[str, Any]]: a list of dictionaries containing the result set
747
  """
748
  # convert cypher query to pgsql/age query
749
- wrapped_query = self._wrap_query(query, self.graph_name, **params)
750
 
751
  # execute the query, rolling back on an error
752
  try:
@@ -758,22 +699,16 @@ class PGGraphStorage(BaseGraphStorage):
758
  graph_name=self.graph_name,
759
  )
760
  else:
761
- # for upserting edge, need to run the SQL twice, otherwise cannot update the properties. (First time it will try to create the edge, second time is MERGING)
762
- # It is a bug of AGE as of 2025-01-03, hope it can be resolved in the future.
763
- if upsert_edge:
764
- data = await self.db.execute(
765
- f"{wrapped_query};{wrapped_query};",
766
- for_age=True,
767
- graph_name=self.graph_name,
768
- )
769
- else:
770
- data = await self.db.execute(
771
- wrapped_query, for_age=True, graph_name=self.graph_name
772
- )
773
  except Exception as e:
774
  raise PGGraphQueryException(
775
  {
776
- "message": f"Error executing graph query: {query.format(**params)}",
777
  "wrapped": wrapped_query,
778
  "detail": str(e),
779
  }
@@ -788,77 +723,85 @@ class PGGraphStorage(BaseGraphStorage):
788
  return result
789
 
790
  async def has_node(self, node_id: str) -> bool:
791
- entity_name_label = node_id.strip('"')
 
 
 
 
 
792
 
793
- query = """MATCH (n:`{label}`) RETURN count(n) > 0 AS node_exists"""
794
- params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
795
- single_result = (await self._query(query, **params))[0]
796
  logger.debug(
797
  "{%s}:query:{%s}:result:{%s}",
798
  inspect.currentframe().f_code.co_name,
799
- query.format(**params),
800
  single_result["node_exists"],
801
  )
802
 
803
  return single_result["node_exists"]
804
 
805
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
806
- entity_name_label_source = source_node_id.strip('"')
807
- entity_name_label_target = target_node_id.strip('"')
 
 
 
 
 
 
 
 
 
808
 
809
- query = """MATCH (a:`{src_label}`)-[r]-(b:`{tgt_label}`)
810
- RETURN COUNT(r) > 0 AS edge_exists"""
811
- params = {
812
- "src_label": PGGraphStorage._encode_graph_label(entity_name_label_source),
813
- "tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target),
814
- }
815
- single_result = (await self._query(query, **params))[0]
816
  logger.debug(
817
  "{%s}:query:{%s}:result:{%s}",
818
  inspect.currentframe().f_code.co_name,
819
- query.format(**params),
820
  single_result["edge_exists"],
821
  )
822
  return single_result["edge_exists"]
823
 
824
  async def get_node(self, node_id: str) -> Union[dict, None]:
825
- entity_name_label = node_id.strip('"')
826
- query = """MATCH (n:`{label}`) RETURN n"""
827
- params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
828
- record = await self._query(query, **params)
 
 
829
  if record:
830
  node = record[0]
831
  node_dict = node["n"]
832
  logger.debug(
833
  "{%s}: query: {%s}, result: {%s}",
834
  inspect.currentframe().f_code.co_name,
835
- query.format(**params),
836
  node_dict,
837
  )
838
  return node_dict
839
  return None
840
 
841
  async def node_degree(self, node_id: str) -> int:
842
- entity_name_label = node_id.strip('"')
843
 
844
- query = """MATCH (n:`{label}`)-[]->(x) RETURN count(x) AS total_edge_count"""
845
- params = {"label": PGGraphStorage._encode_graph_label(entity_name_label)}
846
- record = (await self._query(query, **params))[0]
 
 
847
  if record:
848
  edge_count = int(record["total_edge_count"])
849
  logger.debug(
850
  "{%s}:query:{%s}:result:{%s}",
851
  inspect.currentframe().f_code.co_name,
852
- query.format(**params),
853
  edge_count,
854
  )
855
  return edge_count
856
 
857
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
858
- entity_name_label_source = src_id.strip('"')
859
- entity_name_label_target = tgt_id.strip('"')
860
- src_degree = await self.node_degree(entity_name_label_source)
861
- trg_degree = await self.node_degree(entity_name_label_target)
862
 
863
  # Convert None to 0 for addition
864
  src_degree = 0 if src_degree is None else src_degree
@@ -885,23 +828,25 @@ class PGGraphStorage(BaseGraphStorage):
885
  Returns:
886
  list: List of all relationships/edges found
887
  """
888
- entity_name_label_source = source_node_id.strip('"')
889
- entity_name_label_target = target_node_id.strip('"')
890
-
891
- query = """MATCH (a:`{src_label}`)-[r]->(b:`{tgt_label}`)
892
- RETURN properties(r) as edge_properties
893
- LIMIT 1"""
894
- params = {
895
- "src_label": PGGraphStorage._encode_graph_label(entity_name_label_source),
896
- "tgt_label": PGGraphStorage._encode_graph_label(entity_name_label_target),
897
- }
898
- record = await self._query(query, **params)
 
 
899
  if record and record[0] and record[0]["edge_properties"]:
900
  result = record[0]["edge_properties"]
901
  logger.debug(
902
  "{%s}:query:{%s}:result:{%s}",
903
  inspect.currentframe().f_code.co_name,
904
- query.format(**params),
905
  result,
906
  )
907
  return result
@@ -911,24 +856,31 @@ class PGGraphStorage(BaseGraphStorage):
911
  Retrieves all edges (relationships) for a particular node identified by its label.
912
  :return: List of dictionaries containing edge information
913
  """
914
- node_label = source_node_id.strip('"')
 
 
 
 
 
 
 
 
 
915
 
916
- query = """MATCH (n:`{label}`)
917
- OPTIONAL MATCH (n)-[r]-(connected)
918
- RETURN n, r, connected"""
919
- params = {"label": PGGraphStorage._encode_graph_label(node_label)}
920
- results = await self._query(query, **params)
921
  edges = []
922
  for record in results:
923
  source_node = record["n"] if record["n"] else None
924
  connected_node = record["connected"] if record["connected"] else None
925
 
926
  source_label = (
927
- source_node["label"] if source_node and source_node["label"] else None
 
 
928
  )
929
  target_label = (
930
- connected_node["label"]
931
- if connected_node and connected_node["label"]
932
  else None
933
  )
934
 
@@ -950,17 +902,21 @@ class PGGraphStorage(BaseGraphStorage):
950
  node_id: The unique identifier for the node (used as label)
951
  node_data: Dictionary of node properties
952
  """
953
- label = node_id.strip('"')
954
  properties = node_data
955
 
956
- query = """MERGE (n:`{label}`)
957
- SET n += {properties}"""
958
- params = {
959
- "label": PGGraphStorage._encode_graph_label(label),
960
- "properties": PGGraphStorage._format_properties(properties),
961
- }
 
 
 
 
962
  try:
963
- await self._query(query, readonly=False, **params)
964
  logger.debug(
965
  "Upserted node with label '{%s}' and properties: {%s}",
966
  label,
@@ -986,28 +942,30 @@ class PGGraphStorage(BaseGraphStorage):
986
  target_node_id (str): Label of the target node (used as identifier)
987
  edge_data (dict): Dictionary of properties to set on the edge
988
  """
989
- source_node_label = source_node_id.strip('"')
990
- target_node_label = target_node_id.strip('"')
991
  edge_properties = edge_data
992
 
993
- query = """MATCH (source:`{src_label}`)
994
- WITH source
995
- MATCH (target:`{tgt_label}`)
996
- MERGE (source)-[r:DIRECTED]->(target)
997
- SET r += {properties}
998
- RETURN r"""
999
- params = {
1000
- "src_label": PGGraphStorage._encode_graph_label(source_node_label),
1001
- "tgt_label": PGGraphStorage._encode_graph_label(target_node_label),
1002
- "properties": PGGraphStorage._format_properties(edge_properties),
1003
- }
 
 
1004
  # logger.info(f"-- inserting edge after formatted: {params}")
1005
  try:
1006
- await self._query(query, readonly=False, upsert_edge=True, **params)
1007
  logger.debug(
1008
  "Upserted edge from '{%s}' to '{%s}' with properties: {%s}",
1009
- source_node_label,
1010
- target_node_label,
1011
  edge_properties,
1012
  )
1013
  except Exception as e:
 
130
  data: Union[list, dict] = None,
131
  for_age: bool = False,
132
  graph_name: str = None,
133
+ upsert: bool = False,
134
  ):
135
  try:
136
  async with self.pool.acquire() as connection:
 
141
  await connection.execute(sql)
142
  else:
143
  await connection.execute(sql, *data.values())
144
+ except asyncpg.exceptions.UniqueViolationError as e:
145
+ if upsert:
146
+ print("Key value duplicate, but upsert succeeded.")
147
+ else:
148
+ logger.error(f"Upsert error: {e}")
149
  except Exception as e:
150
  logger.error(f"PostgreSQL database error: {e}")
151
  print(sql)
 
574
 
575
  if dtype == "vertex":
576
  vertex = json.loads(v)
577
+ field = vertex.get("properties")
578
  if not field:
579
  field = {}
580
+ field["label"] = PGGraphStorage._decode_graph_label(field["node_id"])
581
  d[k] = field
582
  # convert edge from id-label->id by replacing id with node information
583
  # we only do this if the vertex was also returned in the query
 
672
  # otherwise return the value stripping out some common special chars
673
  return field.replace("(", "_").replace(")", "")
674
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
675
  async def _query(
676
+ self, query: str, readonly: bool = True, upsert: bool = False
677
  ) -> List[Dict[str, Any]]:
678
  """
679
  Query the graph by taking a cypher query, converting it to an
 
687
  List[Dict[str, Any]]: a list of dictionaries containing the result set
688
  """
689
  # convert cypher query to pgsql/age query
690
+ wrapped_query = query
691
 
692
  # execute the query, rolling back on an error
693
  try:
 
699
  graph_name=self.graph_name,
700
  )
701
  else:
702
+ data = await self.db.execute(
703
+ wrapped_query,
704
+ for_age=True,
705
+ graph_name=self.graph_name,
706
+ upsert=upsert,
707
+ )
 
 
 
 
 
 
708
  except Exception as e:
709
  raise PGGraphQueryException(
710
  {
711
+ "message": f"Error executing graph query: {query}",
712
  "wrapped": wrapped_query,
713
  "detail": str(e),
714
  }
 
723
  return result
724
 
725
  async def has_node(self, node_id: str) -> bool:
726
+ entity_name_label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
727
+
728
+ query = """SELECT * FROM cypher('%s', $$
729
+ MATCH (n:Entity {node_id: "%s"})
730
+ RETURN count(n) > 0 AS node_exists
731
+ $$) AS (node_exists bool)""" % (self.graph_name, entity_name_label)
732
 
733
+ single_result = (await self._query(query))[0]
 
 
734
  logger.debug(
735
  "{%s}:query:{%s}:result:{%s}",
736
  inspect.currentframe().f_code.co_name,
737
+ query,
738
  single_result["node_exists"],
739
  )
740
 
741
  return single_result["node_exists"]
742
 
743
  async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
744
+ src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
745
+ tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
746
+
747
+ query = """SELECT * FROM cypher('%s', $$
748
+ MATCH (a:Entity {node_id: "%s"})-[r]-(b:Entity {node_id: "%s"})
749
+ RETURN COUNT(r) > 0 AS edge_exists
750
+ $$) AS (edge_exists bool)""" % (
751
+ self.graph_name,
752
+ src_label,
753
+ tgt_label,
754
+ )
755
 
756
+ single_result = (await self._query(query))[0]
 
 
 
 
 
 
757
  logger.debug(
758
  "{%s}:query:{%s}:result:{%s}",
759
  inspect.currentframe().f_code.co_name,
760
+ query,
761
  single_result["edge_exists"],
762
  )
763
  return single_result["edge_exists"]
764
 
765
  async def get_node(self, node_id: str) -> Union[dict, None]:
766
+ label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
767
+ query = """SELECT * FROM cypher('%s', $$
768
+ MATCH (n:Entity {node_id: "%s"})
769
+ RETURN n
770
+ $$) AS (n agtype)""" % (self.graph_name, label)
771
+ record = await self._query(query)
772
  if record:
773
  node = record[0]
774
  node_dict = node["n"]
775
  logger.debug(
776
  "{%s}: query: {%s}, result: {%s}",
777
  inspect.currentframe().f_code.co_name,
778
+ query,
779
  node_dict,
780
  )
781
  return node_dict
782
  return None
783
 
784
  async def node_degree(self, node_id: str) -> int:
785
+ label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
786
 
787
+ query = """SELECT * FROM cypher('%s', $$
788
+ MATCH (n:Entity {node_id: "%s"})-[]->(x)
789
+ RETURN count(x) AS total_edge_count
790
+ $$) AS (total_edge_count integer)""" % (self.graph_name, label)
791
+ record = (await self._query(query))[0]
792
  if record:
793
  edge_count = int(record["total_edge_count"])
794
  logger.debug(
795
  "{%s}:query:{%s}:result:{%s}",
796
  inspect.currentframe().f_code.co_name,
797
+ query,
798
  edge_count,
799
  )
800
  return edge_count
801
 
802
  async def edge_degree(self, src_id: str, tgt_id: str) -> int:
803
+ src_degree = await self.node_degree(src_id)
804
+ trg_degree = await self.node_degree(tgt_id)
 
 
805
 
806
  # Convert None to 0 for addition
807
  src_degree = 0 if src_degree is None else src_degree
 
828
  Returns:
829
  list: List of all relationships/edges found
830
  """
831
+ src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
832
+ tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
833
+
834
+ query = """SELECT * FROM cypher('%s', $$
835
+ MATCH (a:Entity {node_id: "%s"})-[r]->(b:Entity {node_id: "%s"})
836
+ RETURN properties(r) as edge_properties
837
+ LIMIT 1
838
+ $$) AS (edge_properties agtype)""" % (
839
+ self.graph_name,
840
+ src_label,
841
+ tgt_label,
842
+ )
843
+ record = await self._query(query)
844
  if record and record[0] and record[0]["edge_properties"]:
845
  result = record[0]["edge_properties"]
846
  logger.debug(
847
  "{%s}:query:{%s}:result:{%s}",
848
  inspect.currentframe().f_code.co_name,
849
+ query,
850
  result,
851
  )
852
  return result
 
856
  Retrieves all edges (relationships) for a particular node identified by its label.
857
  :return: List of dictionaries containing edge information
858
  """
859
+ label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
860
+
861
+ query = """SELECT * FROM cypher('%s', $$
862
+ MATCH (n:Entity {node_id: "%s"})
863
+ OPTIONAL MATCH (n)-[r]-(connected)
864
+ RETURN n, r, connected
865
+ $$) AS (n agtype, r agtype, connected agtype)""" % (
866
+ self.graph_name,
867
+ label,
868
+ )
869
 
870
+ results = await self._query(query)
 
 
 
 
871
  edges = []
872
  for record in results:
873
  source_node = record["n"] if record["n"] else None
874
  connected_node = record["connected"] if record["connected"] else None
875
 
876
  source_label = (
877
+ source_node["node_id"]
878
+ if source_node and source_node["node_id"]
879
+ else None
880
  )
881
  target_label = (
882
+ connected_node["node_id"]
883
+ if connected_node and connected_node["node_id"]
884
  else None
885
  )
886
 
 
902
  node_id: The unique identifier for the node (used as label)
903
  node_data: Dictionary of node properties
904
  """
905
+ label = PGGraphStorage._encode_graph_label(node_id.strip('"'))
906
  properties = node_data
907
 
908
+ query = """SELECT * FROM cypher('%s', $$
909
+ MERGE (n:Entity {node_id: "%s"})
910
+ SET n += %s
911
+ RETURN n
912
+ $$) AS (n agtype)""" % (
913
+ self.graph_name,
914
+ label,
915
+ PGGraphStorage._format_properties(properties),
916
+ )
917
+
918
  try:
919
+ await self._query(query, readonly=False, upsert=True)
920
  logger.debug(
921
  "Upserted node with label '{%s}' and properties: {%s}",
922
  label,
 
942
  target_node_id (str): Label of the target node (used as identifier)
943
  edge_data (dict): Dictionary of properties to set on the edge
944
  """
945
+ src_label = PGGraphStorage._encode_graph_label(source_node_id.strip('"'))
946
+ tgt_label = PGGraphStorage._encode_graph_label(target_node_id.strip('"'))
947
  edge_properties = edge_data
948
 
949
+ query = """SELECT * FROM cypher('%s', $$
950
+ MATCH (source:Entity {node_id: "%s"})
951
+ WITH source
952
+ MATCH (target:Entity {node_id: "%s"})
953
+ MERGE (source)-[r:DIRECTED]->(target)
954
+ SET r += %s
955
+ RETURN r
956
+ $$) AS (r agtype)""" % (
957
+ self.graph_name,
958
+ src_label,
959
+ tgt_label,
960
+ PGGraphStorage._format_properties(edge_properties),
961
+ )
962
  # logger.info(f"-- inserting edge after formatted: {params}")
963
  try:
964
+ await self._query(query, readonly=False, upsert=True)
965
  logger.debug(
966
  "Upserted edge from '{%s}' to '{%s}' with properties: {%s}",
967
+ src_label,
968
+ tgt_label,
969
  edge_properties,
970
  )
971
  except Exception as e: