yangdx committed on
Commit
e911ecd
·
1 Parent(s): bd2684f

Fix cosine threshold parameter setting error for chroma

Browse files
lightrag/kg/chroma_impl.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import asyncio
2
  from dataclasses import dataclass
3
  from typing import Union
@@ -12,16 +13,16 @@ from lightrag.utils import logger
12
  class ChromaVectorDBStorage(BaseVectorStorage):
13
  """ChromaDB vector storage implementation."""
14
 
15
- cosine_better_than_threshold: float = 0.2
16
 
17
  def __post_init__(self):
18
  try:
19
  # Use global config value if specified, otherwise use default
20
- self.cosine_better_than_threshold = self.global_config.get(
 
21
  "cosine_better_than_threshold", self.cosine_better_than_threshold
22
  )
23
 
24
- config = self.global_config.get("vector_db_storage_cls_kwargs", {})
25
  user_collection_settings = config.get("collection_settings", {})
26
  # Default HNSW index settings for ChromaDB
27
  default_collection_settings = {
 
1
+ import os
2
  import asyncio
3
  from dataclasses import dataclass
4
  from typing import Union
 
13
  class ChromaVectorDBStorage(BaseVectorStorage):
14
  """ChromaDB vector storage implementation."""
15
 
16
+ cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
17
 
18
  def __post_init__(self):
19
  try:
20
  # Use global config value if specified, otherwise use default
21
+ config = self.global_config.get("vector_db_storage_cls_kwargs", {})
22
+ self.cosine_better_than_threshold = config.get(
23
  "cosine_better_than_threshold", self.cosine_better_than_threshold
24
  )
25
 
 
26
  user_collection_settings = config.get("collection_settings", {})
27
  # Default HNSW index settings for ChromaDB
28
  default_collection_settings = {
lightrag/kg/nano_vector_db_impl.py CHANGED
@@ -76,6 +76,12 @@ class NanoVectorDBStorage(BaseVectorStorage):
76
  cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
77
 
78
  def __post_init__(self):
 
 
 
 
 
 
79
  self._client_file_name = os.path.join(
80
  self.global_config["working_dir"], f"vdb_{self.namespace}.json"
81
  )
@@ -83,14 +89,6 @@ class NanoVectorDBStorage(BaseVectorStorage):
83
  self._client = NanoVectorDB(
84
  self.embedding_func.embedding_dim, storage_file=self._client_file_name
85
  )
86
- # get cosine_better_than_threshold from LightRAG
87
- vector_db_kwargs = self.global_config.get("vector_db_storage_cls_kwargs", {})
88
- self.cosine_better_than_threshold = vector_db_kwargs.get(
89
- "cosine_better_than_threshold",
90
- self.global_config.get(
91
- "cosine_better_than_threshold", self.cosine_better_than_threshold
92
- ),
93
- )
94
 
95
  async def upsert(self, data: dict[str, dict]):
96
  logger.info(f"Inserting {len(data)} vectors to {self.namespace}")
 
76
  cosine_better_than_threshold: float = float(os.getenv("COSINE_THRESHOLD", "0.2"))
77
 
78
  def __post_init__(self):
79
+ # Use global config value if specified, otherwise use default
80
+ config = self.global_config.get("vector_db_storage_cls_kwargs", {})
81
+ self.cosine_better_than_threshold = config.get(
82
+ "cosine_better_than_threshold", self.cosine_better_than_threshold
83
+ )
84
+
85
  self._client_file_name = os.path.join(
86
  self.global_config["working_dir"], f"vdb_{self.namespace}.json"
87
  )
 
89
  self._client = NanoVectorDB(
90
  self.embedding_func.embedding_dim, storage_file=self._client_file_name
91
  )
 
 
 
 
 
 
 
 
92
 
93
  async def upsert(self, data: dict[str, dict]):
94
  logger.info(f"Inserting {len(data)} vectors to {self.namespace}")