destinyhhx commited on
Commit
bef4c6d
·
1 Parent(s): f533b76

Fix embedding type conversion issue in the query function of chroma_impl; chroma_impl supports local persistent client: PersistentClient

Browse files
examples/test_chromadb.py CHANGED
@@ -15,6 +15,10 @@ if not os.path.exists(WORKING_DIR):
15
  os.mkdir(WORKING_DIR)
16
 
17
  # ChromaDB Configuration
 
 
 
 
18
  CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost")
19
  CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000))
20
  CHROMADB_AUTH_TOKEN = os.environ.get("CHROMADB_AUTH_TOKEN", "secret-token")
@@ -60,30 +64,50 @@ async def create_embedding_function_instance():
60
 
61
  async def initialize_rag():
62
  embedding_func_instance = await create_embedding_function_instance()
63
-
64
- return LightRAG(
65
- working_dir=WORKING_DIR,
66
- llm_model_func=gpt_4o_mini_complete,
67
- embedding_func=embedding_func_instance,
68
- vector_storage="ChromaVectorDBStorage",
69
- log_level="DEBUG",
70
- embedding_batch_num=32,
71
- vector_db_storage_cls_kwargs={
72
- "host": CHROMADB_HOST,
73
- "port": CHROMADB_PORT,
74
- "auth_token": CHROMADB_AUTH_TOKEN,
75
- "auth_provider": CHROMADB_AUTH_PROVIDER,
76
- "auth_header_name": CHROMADB_AUTH_HEADER,
77
- "collection_settings": {
78
- "hnsw:space": "cosine",
79
- "hnsw:construction_ef": 128,
80
- "hnsw:search_ef": 128,
81
- "hnsw:M": 16,
82
- "hnsw:batch_size": 100,
83
- "hnsw:sync_threshold": 1000,
84
  },
85
- },
86
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
 
89
  # Run the initialization
 
15
  os.mkdir(WORKING_DIR)
16
 
17
  # ChromaDB Configuration
18
+ CHROMADB_USE_LOCAL_PERSISTENT = False
19
+ # Local PersistentClient Configuration
20
+ CHROMADB_LOCAL_PATH = os.environ.get("CHROMADB_LOCAL_PATH", os.path.join(WORKING_DIR, "chroma_data"))
21
+ # Remote HttpClient Configuration
22
  CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost")
23
  CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000))
24
  CHROMADB_AUTH_TOKEN = os.environ.get("CHROMADB_AUTH_TOKEN", "secret-token")
 
64
 
65
  async def initialize_rag():
66
  embedding_func_instance = await create_embedding_function_instance()
67
+ if CHROMADB_USE_LOCAL_PERSISTENT:
68
+ return LightRAG(
69
+ working_dir=WORKING_DIR,
70
+ llm_model_func=gpt_4o_mini_complete,
71
+ embedding_func=embedding_func_instance,
72
+ vector_storage="ChromaVectorDBStorage",
73
+ log_level="DEBUG",
74
+ embedding_batch_num=32,
75
+ vector_db_storage_cls_kwargs={
76
+ "local_path": CHROMADB_LOCAL_PATH,
77
+ "collection_settings": {
78
+ "hnsw:space": "cosine",
79
+ "hnsw:construction_ef": 128,
80
+ "hnsw:search_ef": 128,
81
+ "hnsw:M": 16,
82
+ "hnsw:batch_size": 100,
83
+ "hnsw:sync_threshold": 1000,
84
+ },
 
 
 
85
  },
86
+ )
87
+ else:
88
+ return LightRAG(
89
+ working_dir=WORKING_DIR,
90
+ llm_model_func=gpt_4o_mini_complete,
91
+ embedding_func=embedding_func_instance,
92
+ vector_storage="ChromaVectorDBStorage",
93
+ log_level="DEBUG",
94
+ embedding_batch_num=32,
95
+ vector_db_storage_cls_kwargs={
96
+ "host": CHROMADB_HOST,
97
+ "port": CHROMADB_PORT,
98
+ "auth_token": CHROMADB_AUTH_TOKEN,
99
+ "auth_provider": CHROMADB_AUTH_PROVIDER,
100
+ "auth_header_name": CHROMADB_AUTH_HEADER,
101
+ "collection_settings": {
102
+ "hnsw:space": "cosine",
103
+ "hnsw:construction_ef": 128,
104
+ "hnsw:search_ef": 128,
105
+ "hnsw:M": 16,
106
+ "hnsw:batch_size": 100,
107
+ "hnsw:sync_threshold": 1000,
108
+ },
109
+ },
110
+ )
111
 
112
 
113
  # Run the initialization
lightrag/kg/chroma_impl.py CHANGED
@@ -3,7 +3,7 @@ import asyncio
3
  from dataclasses import dataclass
4
  from typing import Union
5
  import numpy as np
6
- from chromadb import HttpClient
7
  from chromadb.config import Settings
8
  from lightrag.base import BaseVectorStorage
9
  from lightrag.utils import logger
@@ -48,31 +48,41 @@ class ChromaVectorDBStorage(BaseVectorStorage):
48
  **user_collection_settings,
49
  }
50
 
51
- auth_provider = config.get(
52
- "auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
53
- )
54
- auth_credentials = config.get("auth_token", "secret-token")
55
- headers = {}
56
-
57
- if "token_authn" in auth_provider:
58
- headers = {
59
- config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
60
- }
61
- elif "basic_authn" in auth_provider:
62
- auth_credentials = config.get("auth_credentials", "admin:admin")
63
-
64
- self._client = HttpClient(
65
- host=config.get("host", "localhost"),
66
- port=config.get("port", 8000),
67
- headers=headers,
68
- settings=Settings(
69
- chroma_api_impl="rest",
70
- chroma_client_auth_provider=auth_provider,
71
- chroma_client_auth_credentials=auth_credentials,
72
- allow_reset=True,
73
- anonymized_telemetry=False,
74
- ),
75
- )
 
 
 
 
 
 
 
 
 
 
76
 
77
  self._collection = self._client.get_or_create_collection(
78
  name=self.namespace,
@@ -143,7 +153,7 @@ class ChromaVectorDBStorage(BaseVectorStorage):
143
  embedding = await self.embedding_func([query])
144
 
145
  results = self._collection.query(
146
- query_embeddings=embedding.tolist(),
147
  n_results=top_k * 2, # Request more results to allow for filtering
148
  include=["metadatas", "distances", "documents"],
149
  )
 
3
  from dataclasses import dataclass
4
  from typing import Union
5
  import numpy as np
6
+ from chromadb import HttpClient, PersistentClient
7
  from chromadb.config import Settings
8
  from lightrag.base import BaseVectorStorage
9
  from lightrag.utils import logger
 
48
  **user_collection_settings,
49
  }
50
 
51
+ local_path = config.get("local_path", None)
52
+ if local_path:
53
+ self._client = PersistentClient(
54
+ path=local_path,
55
+ settings=Settings(
56
+ allow_reset=True,
57
+ anonymized_telemetry=False,
58
+ ),
59
+ )
60
+ else:
61
+ auth_provider = config.get(
62
+ "auth_provider", "chromadb.auth.token_authn.TokenAuthClientProvider"
63
+ )
64
+ auth_credentials = config.get("auth_token", "secret-token")
65
+ headers = {}
66
+
67
+ if "token_authn" in auth_provider:
68
+ headers = {
69
+ config.get("auth_header_name", "X-Chroma-Token"): auth_credentials
70
+ }
71
+ elif "basic_authn" in auth_provider:
72
+ auth_credentials = config.get("auth_credentials", "admin:admin")
73
+
74
+ self._client = HttpClient(
75
+ host=config.get("host", "localhost"),
76
+ port=config.get("port", 8000),
77
+ headers=headers,
78
+ settings=Settings(
79
+ chroma_api_impl="rest",
80
+ chroma_client_auth_provider=auth_provider,
81
+ chroma_client_auth_credentials=auth_credentials,
82
+ allow_reset=True,
83
+ anonymized_telemetry=False,
84
+ ),
85
+ )
86
 
87
  self._collection = self._client.get_or_create_collection(
88
  name=self.namespace,
 
153
  embedding = await self.embedding_func([query])
154
 
155
  results = self._collection.query(
156
+ query_embeddings=embedding.tolist() if not isinstance(embedding, list) else embedding,
157
  n_results=top_k * 2, # Request more results to allow for filtering
158
  include=["metadatas", "distances", "documents"],
159
  )