zrguo committed (unverified)
Commit aaf3e3f · 2 parent(s): 492af66 764ded8

Merge pull request #447 from spo0nman/pkaushal/vectordb-chroma 
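This merge adds a ChromaDB-backed vector storage: a new ChromaVectorDBStorage implementation (lightrag/kg/chroma_impl.py), its registration in lightrag/lightrag.py, and a test script (test_chromadb.py) that exercises it against a remote Chroma server over HTTP.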

lightrag/kg/chroma_impl.py ADDED
@@ -0,0 +1,172 @@
+import asyncio
+from dataclasses import dataclass
+from typing import Union
+import numpy as np
+from chromadb import HttpClient
+from chromadb.config import Settings
+from lightrag.base import BaseVectorStorage
+from lightrag.utils import logger
+
+
+@dataclass
+class ChromaVectorDBStorage(BaseVectorStorage):
+    """ChromaDB vector storage implementation."""
+
+    cosine_better_than_threshold: float = 0.2
+
+    def __post_init__(self):
+        try:
+            # Use global config value if specified, otherwise use default
+            self.cosine_better_than_threshold = self.global_config.get(
+                "cosine_better_than_threshold", self.cosine_better_than_threshold
+            )
+
+            config = self.global_config.get("vector_db_storage_cls_kwargs", {})
+            user_collection_settings = config.get("collection_settings", {})
+            # Default HNSW index settings for ChromaDB
+            default_collection_settings = {
+                # Distance metric used for similarity search (cosine similarity)
+                "hnsw:space": "cosine",
+                # Number of nearest neighbors to explore during index construction
+                # Higher values = better recall but slower indexing
+                "hnsw:construction_ef": 128,
+                # Number of nearest neighbors to explore during search
+                # Higher values = better recall but slower search
+                "hnsw:search_ef": 128,
+                # Number of connections per node in the HNSW graph
+                # Higher values = better recall but more memory usage
+                "hnsw:M": 16,
+                # Number of vectors to process in one batch during indexing
+                "hnsw:batch_size": 100,
+                # Number of updates before forcing index synchronization
+                # Lower values = more frequent syncs but slower indexing
+                "hnsw:sync_threshold": 1000,
+            }
+            collection_settings = {
+                **default_collection_settings,
+                **user_collection_settings,
+            }
+
+            auth_provider = config.get(
+                "auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
+            )
+            auth_credentials = config.get("auth_token", "secret-token")
+            headers = {}
+
+            if "token_authn" in auth_provider:
+                headers = {
+                    config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
+                }
+            elif "basic_authn" in auth_provider:
+                auth_credentials = config.get("auth_credentials", "admin:admin")
+
+            self._client = HttpClient(
+                host=config.get("host", "localhost"),
+                port=config.get("port", 8000),
+                headers=headers,
+                settings=Settings(
+                    chroma_api_impl="rest",
+                    chroma_client_auth_provider=auth_provider,
+                    chroma_client_auth_credentials=auth_credentials,
+                    allow_reset=True,
+                    anonymized_telemetry=False,
+                ),
+            )
+
+            self._collection = self._client.get_or_create_collection(
+                name=self.namespace,
+                metadata={
+                    **collection_settings,
+                    "dimension": self.embedding_func.embedding_dim,
+                },
+            )
+            # Use batch size from collection settings if specified
+            self._max_batch_size = self.global_config.get(
+                "embedding_batch_num", collection_settings.get("hnsw:batch_size", 32)
+            )
+        except Exception as e:
+            logger.error(f"ChromaDB initialization failed: {str(e)}")
+            raise
+
+    async def upsert(self, data: dict[str, dict]):
+        if not data:
+            logger.warning("Empty data provided to vector DB")
+            return []
+
+        try:
+            ids = list(data.keys())
+            documents = [v["content"] for v in data.values()]
+            metadatas = [
+                {k: v for k, v in item.items() if k in self.meta_fields}
+                or {"_default": "true"}
+                for item in data.values()
+            ]
+
+            # Split documents into batches for embedding
+            batches = [
+                documents[i : i + self._max_batch_size]
+                for i in range(0, len(documents), self._max_batch_size)
+            ]
+
+            # Embed all batches concurrently; asyncio.gather preserves input order
+            embedding_tasks = [self.embedding_func(batch) for batch in batches]
+            embeddings_list = list(await asyncio.gather(*embedding_tasks))
+            embeddings = np.concatenate(embeddings_list)
+
+            # Upsert in batches
+            for i in range(0, len(ids), self._max_batch_size):
+                batch_slice = slice(i, i + self._max_batch_size)
+
+                self._collection.upsert(
+                    ids=ids[batch_slice],
+                    embeddings=embeddings[batch_slice].tolist(),
+                    documents=documents[batch_slice],
+                    metadatas=metadatas[batch_slice],
+                )
+
+            return ids
+
+        except Exception as e:
+            logger.error(f"Error during ChromaDB upsert: {str(e)}")
+            raise
+
+    async def query(self, query: str, top_k=5) -> Union[dict, list[dict]]:
+        try:
+            embedding = await self.embedding_func([query])
+
+            results = self._collection.query(
+                query_embeddings=embedding.tolist(),
+                n_results=top_k * 2,  # Request more results to allow for filtering
+                include=["metadatas", "distances", "documents"],
+            )
+
+            # With "hnsw:space": "cosine", ChromaDB returns cosine *distances*
+            # (0 = identical). Convert each to a similarity via (1 - distance),
+            # keep only results at or above cosine_better_than_threshold, then
+            # take the top k. We requested 2x results above so enough remain
+            # after filtering. Note the "distance" field below actually holds
+            # the similarity score.
+            return [
+                {
+                    "id": results["ids"][0][i],
+                    "distance": 1 - results["distances"][0][i],
+                    "content": results["documents"][0][i],
+                    **results["metadatas"][0][i],
+                }
+                for i in range(len(results["ids"][0]))
+                if (1 - results["distances"][0][i]) >= self.cosine_better_than_threshold
+            ][:top_k]
+
+        except Exception as e:
+            logger.error(f"Error during ChromaDB query: {str(e)}")
+            raise
+
+    async def index_done_callback(self):
+        # ChromaDB handles persistence automatically
+        pass
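A quick sketch of the filtering in query() above, with made-up numbers: ChromaDB returns cosine distances, which the code converts to similarities before applying cosine_better_than_threshold and truncating to top k.

    # Illustration only; mirrors the list comprehension in query()
    distances = [0.05, 0.30, 0.95]  # cosine distances from ChromaDB (0 = identical)
    threshold = 0.2                 # cosine_better_than_threshold default
    top_k = 2
    similarities = [1 - d for d in distances]                    # [0.95, 0.70, 0.05]
    kept = [s for s in similarities if s >= threshold][:top_k]   # [0.95, 0.70]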
lightrag/lightrag.py CHANGED
@@ -76,6 +76,7 @@ OracleGraphStorage = lazy_external_import(".kg.oracle_impl", "OracleGraphStorage
 OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBStorage")
 MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge")
 MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage")
+ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage")
 
 
 def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
@@ -263,6 +264,7 @@ class LightRAG:
             "NanoVectorDBStorage": NanoVectorDBStorage,
             "OracleVectorDBStorage": OracleVectorDBStorage,
             "MilvusVectorDBStorge": MilvusVectorDBStorge,
+            "ChromaVectorDBStorage": ChromaVectorDBStorage,
             # graph storage
             "NetworkXStorage": NetworkXStorage,
             "Neo4JStorage": Neo4JStorage,
test_chromadb.py ADDED
@@ -0,0 +1,113 @@
+import os
+import asyncio
+from lightrag import LightRAG, QueryParam
+from lightrag.llm import gpt_4o_mini_complete, openai_embedding
+from lightrag.utils import EmbeddingFunc
+import numpy as np
+
+#########
+# Uncomment the two lines below when running in a Jupyter notebook to
+# handle the async nature of rag.insert()
+# import nest_asyncio
+# nest_asyncio.apply()
+#########
+WORKING_DIR = "./chromadb_test_dir"
+if not os.path.exists(WORKING_DIR):
+    os.mkdir(WORKING_DIR)
+
+# ChromaDB Configuration
+CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost")
+CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000))
+CHROMADB_AUTH_TOKEN = os.environ.get("CHROMADB_AUTH_TOKEN", "secret-token")
+CHROMADB_AUTH_PROVIDER = os.environ.get(
+    "CHROMADB_AUTH_PROVIDER", "chromadb.auth.token_authn.TokenAuthClientProvider"
+)
+CHROMADB_AUTH_HEADER = os.environ.get("CHROMADB_AUTH_HEADER", "X-Chroma-Token")
+
+# Embedding Configuration and Functions
+EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
+EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
+
+# ChromaDB requires knowing the embedding dimension up front when creating
+# a collection. The dimension is model-specific (e.g. text-embedding-3-large
+# uses 3072 dimensions), so we determine it dynamically by running a test
+# embedding and pass it to the ChromaVectorDBStorage class.
+
+
+async def embedding_func(texts: list[str]) -> np.ndarray:
+    return await openai_embedding(
+        texts,
+        model=EMBEDDING_MODEL,
+    )
+
+
+async def get_embedding_dimension():
+    test_text = ["This is a test sentence."]
+    embedding = await embedding_func(test_text)
+    return embedding.shape[1]
+
+
+async def create_embedding_function_instance():
+    # Get embedding dimension
+    embedding_dimension = await get_embedding_dimension()
+    # Create embedding function instance
+    return EmbeddingFunc(
+        embedding_dim=embedding_dimension,
+        max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
+        func=embedding_func,
+    )
+
+
+async def initialize_rag():
+    embedding_func_instance = await create_embedding_function_instance()
+
+    return LightRAG(
+        working_dir=WORKING_DIR,
+        llm_model_func=gpt_4o_mini_complete,
+        embedding_func=embedding_func_instance,
+        vector_storage="ChromaVectorDBStorage",
+        log_level="DEBUG",
+        embedding_batch_num=32,
+        vector_db_storage_cls_kwargs={
+            "host": CHROMADB_HOST,
+            "port": CHROMADB_PORT,
+            "auth_token": CHROMADB_AUTH_TOKEN,
+            "auth_provider": CHROMADB_AUTH_PROVIDER,
+            "auth_header_name": CHROMADB_AUTH_HEADER,
+            "collection_settings": {
+                "hnsw:space": "cosine",
+                "hnsw:construction_ef": 128,
+                "hnsw:search_ef": 128,
+                "hnsw:M": 16,
+                "hnsw:batch_size": 100,
+                "hnsw:sync_threshold": 1000,
+            },
+        },
+    )
+
+
+# Run the initialization
+rag = asyncio.run(initialize_rag())
+
+# with open("./dickens/book.txt", "r", encoding="utf-8") as f:
+#     rag.insert(f.read())
+
+# Perform naive search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="naive"))
+)
+
+# Perform local search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="local"))
+)
+
+# Perform global search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="global"))
+)
+
+# Perform hybrid search
+print(
+    rag.query("What are the top themes in this story?", param=QueryParam(mode="hybrid"))
+)
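The script above exercises the token-auth path; the basic-auth branch in ChromaVectorDBStorage.__post_init__ would instead be configured roughly as below. This is a sketch: any provider string containing "basic_authn" selects that branch, the exact provider class path depends on the installed chromadb version, and user:password is a placeholder credential.

    # Hypothetical basic-auth configuration for vector_db_storage_cls_kwargs
    vector_db_storage_cls_kwargs = {
        "host": "localhost",
        "port": 8000,
        "auth_provider": "chromadb.auth.basic_authn.BasicAuthClientProvider",
        "auth_credentials": "user:password",  # placeholder credentials
    }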