zrguo committed on
Commit af14936 · unverified · 2 Parent(s): 44afdce ddd0cf7

Merge pull request #1032 from ArindamRoy23/main

README.md CHANGED
@@ -176,6 +176,8 @@ class QueryParam:
     """Maximum number of tokens allocated for relationship descriptions in global retrieval."""
     max_token_for_local_context: int = 4000
     """Maximum number of tokens allocated for entity descriptions in local retrieval."""
+    ids: list[str] | None = None # ONLY SUPPORTED FOR PG VECTOR DBs
+    """List of ids to filter the RAG."""
     ...
 ```
 
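To show how the new parameter is meant to be used, here is a minimal usage sketch (not part of this commit). It assumes a LightRAG instance wired to the PostgreSQL storage backends, since the comment above notes the filter is only supported for PG vector DBs; the working directory, storage name, and document ids are placeholders, and the LLM/embedding configuration is omitted.

```python
from lightrag import LightRAG, QueryParam

# Hypothetical setup: working_dir and the storage class name are examples,
# and the required llm_model_func/embedding_func arguments are omitted here.
rag = LightRAG(
    working_dir="./rag_storage",
    vector_storage="PGVectorStorage",
)

# Restrict retrieval to two previously ingested documents (ids are examples).
param = QueryParam(mode="hybrid", top_k=20, ids=["doc-001", "doc-002"])
print(rag.query("What does the contract say about termination?", param=param))
```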
lightrag/base.py CHANGED
@@ -81,6 +81,9 @@ class QueryParam:
     history_turns: int = 3
     """Number of complete conversation turns (user-assistant pairs) to consider in the response context."""
 
+    ids: list[str] | None = None
+    """List of ids to filter the results."""
+
 
 @dataclass
 class StorageNameSpace(ABC):
@@ -107,7 +110,9 @@ class BaseVectorStorage(StorageNameSpace, ABC):
     meta_fields: set[str] = field(default_factory=set)
 
     @abstractmethod
-    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
+    async def query(
+        self, query: str, top_k: int, ids: list[str] | None = None
+    ) -> list[dict[str, Any]]:
         """Query the vector storage and retrieve top_k results."""
 
     @abstractmethod
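Since every call site in operate.py (below) now passes ids as a keyword argument, concrete storages must accept it even if they cannot filter by it. A minimal sketch of a hypothetical backend satisfying the new contract follows; MyVectorStorage and _similarity_search are illustrative only, and the other abstract methods of BaseVectorStorage are omitted.

```python
from dataclasses import dataclass
from typing import Any

from lightrag.base import BaseVectorStorage


@dataclass
class MyVectorStorage(BaseVectorStorage):  # hypothetical example backend
    async def query(
        self, query: str, top_k: int, ids: list[str] | None = None
    ) -> list[dict[str, Any]]:
        """Query the vector storage and retrieve top_k results."""
        embedding = (await self.embedding_func([query]))[0]
        results = self._similarity_search(embedding, top_k)  # hypothetical helper
        if ids is not None:
            # Backends without native id filtering can at least post-filter.
            results = [r for r in results if r.get("id") in set(ids)]
        return results[:top_k]
```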
lightrag/kg/postgres_impl.py CHANGED
@@ -438,6 +438,8 @@ class PGVectorStorage(BaseVectorStorage):
             "entity_name": item["entity_name"],
             "content": item["content"],
             "content_vector": json.dumps(item["__vector__"].tolist()),
+            "chunk_id": item["source_id"],
+            # TODO: add document_id
         }
         return upsert_sql, data
 
@@ -450,6 +452,8 @@ class PGVectorStorage(BaseVectorStorage):
             "target_id": item["tgt_id"],
             "content": item["content"],
             "content_vector": json.dumps(item["__vector__"].tolist()),
+            "chunk_id": item["source_id"],
+            # TODO: add document_id
         }
         return upsert_sql, data
 
@@ -492,13 +496,20 @@ class PGVectorStorage(BaseVectorStorage):
         await self.db.execute(upsert_sql, data)
 
     #################### query method ###############
-    async def query(self, query: str, top_k: int) -> list[dict[str, Any]]:
+    async def query(
+        self, query: str, top_k: int, ids: list[str] | None = None
+    ) -> list[dict[str, Any]]:
         embeddings = await self.embedding_func([query])
        embedding = embeddings[0]
         embedding_string = ",".join(map(str, embedding))
 
+        if ids:
+            formatted_ids = ",".join(f"'{id}'" for id in ids)
+        else:
+            formatted_ids = "NULL"
+
         sql = SQL_TEMPLATES[self.base_namespace].format(
-            embedding_string=embedding_string
+            embedding_string=embedding_string, doc_ids=formatted_ids
         )
         params = {
             "workspace": self.db.workspace,
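As a standalone illustration of what the new ids handling above produces (the helper below simply mirrors the diff; it is not code from this commit): the formatted string is substituted into the {doc_ids} placeholder of the SQL templates further down, where the NULL sentinel makes the `{doc_ids} IS NULL` branch true and effectively disables filtering.

```python
def format_doc_ids(ids: list[str] | None) -> str:
    # Mirrors the logic added to PGVectorStorage.query above.
    return ",".join(f"'{i}'" for i in ids) if ids else "NULL"

print(format_doc_ids(None))                    # NULL  -> "{doc_ids} IS NULL" holds, no filtering
print(format_doc_ids(["doc-001"]))             # 'doc-001'
print(format_doc_ids(["doc-001", "doc-002"]))  # 'doc-001','doc-002'
```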
 
@@ -1491,6 +1502,7 @@ TABLES = {
             content_vector VECTOR,
             create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
             update_time TIMESTAMP,
+            chunk_id VARCHAR(255) NULL,
             CONSTRAINT LIGHTRAG_VDB_ENTITY_PK PRIMARY KEY (workspace, id)
             )"""
     },
@@ -1504,6 +1516,7 @@ TABLES = {
             content_vector VECTOR,
             create_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
             update_time TIMESTAMP,
+            chunk_id VARCHAR(255) NULL,
             CONSTRAINT LIGHTRAG_VDB_RELATION_PK PRIMARY KEY (workspace, id)
             )"""
     },
@@ -1586,8 +1599,9 @@ SQL_TEMPLATES = {
             content_vector=EXCLUDED.content_vector,
             update_time = CURRENT_TIMESTAMP
             """,
-    "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content, content_vector)
-            VALUES ($1, $2, $3, $4, $5)
+    "upsert_entity": """INSERT INTO LIGHTRAG_VDB_ENTITY (workspace, id, entity_name, content,
+            content_vector, chunk_id)
+            VALUES ($1, $2, $3, $4, $5, $6)
             ON CONFLICT (workspace,id) DO UPDATE
             SET entity_name=EXCLUDED.entity_name,
             content=EXCLUDED.content,
@@ -1595,8 +1609,8 @@ SQL_TEMPLATES = {
             update_time=CURRENT_TIMESTAMP
             """,
     "upsert_relationship": """INSERT INTO LIGHTRAG_VDB_RELATION (workspace, id, source_id,
-            target_id, content, content_vector)
-            VALUES ($1, $2, $3, $4, $5, $6)
+            target_id, content, content_vector, chunk_id)
+            VALUES ($1, $2, $3, $4, $5, $6, $7)
             ON CONFLICT (workspace,id) DO UPDATE
             SET source_id=EXCLUDED.source_id,
             target_id=EXCLUDED.target_id,
@@ -1604,21 +1618,21 @@ SQL_TEMPLATES = {
             content_vector=EXCLUDED.content_vector, update_time = CURRENT_TIMESTAMP
             """,
     # SQL for VectorStorage
-    "entities": """SELECT entity_name FROM
-            (SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
-            FROM LIGHTRAG_VDB_ENTITY where workspace=$1)
-            WHERE distance>$2 ORDER BY distance DESC LIMIT $3
-            """,
-    "relationships": """SELECT source_id as src_id, target_id as tgt_id FROM
-            (SELECT id, source_id,target_id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
-            FROM LIGHTRAG_VDB_RELATION where workspace=$1)
-            WHERE distance>$2 ORDER BY distance DESC LIMIT $3
-            """,
-    "chunks": """SELECT id FROM
-            (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
-            FROM LIGHTRAG_DOC_CHUNKS where workspace=$1)
-            WHERE distance>$2 ORDER BY distance DESC LIMIT $3
-            """,
+    # "entities": """SELECT entity_name FROM
+    #         (SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
+    #         FROM LIGHTRAG_VDB_ENTITY where workspace=$1)
+    #         WHERE distance>$2 ORDER BY distance DESC LIMIT $3
+    #         """,
+    # "relationships": """SELECT source_id as src_id, target_id as tgt_id FROM
+    #         (SELECT id, source_id,target_id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
+    #         FROM LIGHTRAG_VDB_RELATION where workspace=$1)
+    #         WHERE distance>$2 ORDER BY distance DESC LIMIT $3
+    #         """,
+    # "chunks": """SELECT id FROM
+    #         (SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
+    #         FROM LIGHTRAG_DOC_CHUNKS where workspace=$1)
+    #         WHERE distance>$2 ORDER BY distance DESC LIMIT $3
+    #         """,
     # DROP tables
     "drop_all": """
             DROP TABLE IF EXISTS LIGHTRAG_DOC_FULL CASCADE;
@@ -1642,4 +1656,55 @@ SQL_TEMPLATES = {
     "drop_vdb_relation": """
             DROP TABLE IF EXISTS LIGHTRAG_VDB_RELATION CASCADE;
             """,
+    "relationships": """
+        WITH relevant_chunks AS (
+            SELECT id as chunk_id
+            FROM LIGHTRAG_DOC_CHUNKS
+            WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}])
+        )
+        SELECT source_id as src_id, target_id as tgt_id
+        FROM (
+            SELECT r.id, r.source_id, r.target_id, 1 - (r.content_vector <=> '[{embedding_string}]'::vector) as distance
+            FROM LIGHTRAG_VDB_RELATION r
+            WHERE r.workspace=$1
+            AND r.chunk_id IN (SELECT chunk_id FROM relevant_chunks)
+        ) filtered
+        WHERE distance>$2
+        ORDER BY distance DESC
+        LIMIT $3
+        """,
+    "entities": """
+        WITH relevant_chunks AS (
+            SELECT id as chunk_id
+            FROM LIGHTRAG_DOC_CHUNKS
+            WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}])
+        )
+        SELECT entity_name FROM
+            (
+                SELECT id, entity_name, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
+                FROM LIGHTRAG_VDB_ENTITY
+                where workspace=$1
+                AND chunk_id IN (SELECT chunk_id FROM relevant_chunks)
+            )
+        WHERE distance>$2
+        ORDER BY distance DESC
+        LIMIT $3
+        """,
+    "chunks": """
+        WITH relevant_chunks AS (
+            SELECT id as chunk_id
+            FROM LIGHTRAG_DOC_CHUNKS
+            WHERE {doc_ids} IS NULL OR full_doc_id = ANY(ARRAY[{doc_ids}])
+        )
+        SELECT id FROM
+            (
+                SELECT id, 1 - (content_vector <=> '[{embedding_string}]'::vector) as distance
+                FROM LIGHTRAG_DOC_CHUNKS
+                where workspace=$1
+                AND id IN (SELECT chunk_id FROM relevant_chunks)
+            )
+        WHERE distance>$2
+        ORDER BY distance DESC
+        LIMIT $3
+        """,
 }
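One operational note, offered as a hedged sketch rather than anything included in this commit: chunk_id is only introduced through the CREATE TABLE IF NOT EXISTS definitions above, so a database created by an earlier version would presumably need the column added manually before the new upserts and filtered queries can run. Rows written before such a migration would keep a NULL chunk_id and so would not match the chunk_id IN (...) filter until they are re-upserted.

```python
# Hypothetical one-off migration for pre-existing deployments (not part of this commit).
MIGRATION_SQL = [
    "ALTER TABLE LIGHTRAG_VDB_ENTITY   ADD COLUMN IF NOT EXISTS chunk_id VARCHAR(255) NULL",
    "ALTER TABLE LIGHTRAG_VDB_RELATION ADD COLUMN IF NOT EXISTS chunk_id VARCHAR(255) NULL",
]
```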
lightrag/operate.py CHANGED
@@ -962,7 +962,10 @@ async def mix_kg_vector_query(
     try:
         # Reduce top_k for vector search in hybrid mode since we have structured information from KG
         mix_topk = min(10, query_param.top_k)
-        results = await chunks_vdb.query(augmented_query, top_k=mix_topk)
+        # TODO: add ids to the query
+        results = await chunks_vdb.query(
+            augmented_query, top_k=mix_topk, ids=query_param.ids
+        )
         if not results:
             return None
 
@@ -1171,7 +1174,11 @@ async def _get_node_data(
     logger.info(
         f"Query nodes: {query}, top_k: {query_param.top_k}, cosine: {entities_vdb.cosine_better_than_threshold}"
     )
-    results = await entities_vdb.query(query, top_k=query_param.top_k)
+
+    results = await entities_vdb.query(
+        query, top_k=query_param.top_k, ids=query_param.ids
+    )
+
     if not len(results):
         return "", "", ""
     # get entity information
@@ -1424,7 +1431,10 @@ async def _get_edge_data(
     logger.info(
         f"Query edges: {keywords}, top_k: {query_param.top_k}, cosine: {relationships_vdb.cosine_better_than_threshold}"
     )
-    results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
+
+    results = await relationships_vdb.query(
+        keywords, top_k=query_param.top_k, ids=query_param.ids
+    )
 
     if not len(results):
         return "", "", ""
@@ -1673,7 +1683,9 @@ async def naive_query(
     if cached_response is not None:
        return cached_response
 
-    results = await chunks_vdb.query(query, top_k=query_param.top_k)
+    results = await chunks_vdb.query(
+        query, top_k=query_param.top_k, ids=query_param.ids
+    )
     if not len(results):
         return PROMPTS["fail_response"]
 
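Read together with the base.py and postgres_impl.py changes above, these call sites simply thread QueryParam.ids through to the storage layer. A condensed, hypothetical trace of the naive path (the function below is illustrative; in the real code this happens inside naive_query):

```python
from lightrag.base import BaseVectorStorage, QueryParam


async def naive_retrieval_sketch(query: str, chunks_vdb: BaseVectorStorage):
    # Example ids; with the PG backend these end up in the "chunks" SQL template,
    # so only chunks whose full_doc_id is in the list are scored.
    param = QueryParam(mode="naive", top_k=5, ids=["doc-001"])
    return await chunks_vdb.query(query, top_k=param.top_k, ids=param.ids)
```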