File size: 5,044 Bytes
0fc6305 0553d6a 0fc6305 8b3b01c 0fc6305 bef4c6d 7149d85 bef4c6d 0fc6305 0553d6a 0fc6305 bef4c6d 8b3b01c bef4c6d 0fc6305 bef4c6d 8b3b01c bef4c6d 0fc6305 8b3b01c 0fc6305 8b3b01c 0fc6305 275e33e 8b3b01c 0fc6305 8b3b01c 0fc6305 8b3b01c 275e33e 8b3b01c 0fc6305 8b3b01c 275e33e 8b3b01c 275e33e 8b3b01c 275e33e 8b3b01c 275e33e 8b3b01c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import os
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
from lightrag.kg.shared_storage import initialize_pipeline_status
#########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
# import nest_asyncio
# nest_asyncio.apply()
#########
WORKING_DIR = "./chromadb_test_dir"
if not os.path.exists(WORKING_DIR):
os.mkdir(WORKING_DIR)
# ChromaDB Configuration
CHROMADB_USE_LOCAL_PERSISTENT = False
# Local PersistentClient Configuration
CHROMADB_LOCAL_PATH = os.environ.get(
"CHROMADB_LOCAL_PATH", os.path.join(WORKING_DIR, "chroma_data")
)
# Remote HttpClient Configuration
CHROMADB_HOST = os.environ.get("CHROMADB_HOST", "localhost")
CHROMADB_PORT = int(os.environ.get("CHROMADB_PORT", 8000))
CHROMADB_AUTH_TOKEN = os.environ.get("CHROMADB_AUTH_TOKEN", "secret-token")
CHROMADB_AUTH_PROVIDER = os.environ.get(
"CHROMADB_AUTH_PROVIDER", "chromadb.auth.token_authn.TokenAuthClientProvider"
)
CHROMADB_AUTH_HEADER = os.environ.get("CHROMADB_AUTH_HEADER", "X-Chroma-Token")
# Embedding Configuration and Functions
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "text-embedding-3-large")
EMBEDDING_MAX_TOKEN_SIZE = int(os.environ.get("EMBEDDING_MAX_TOKEN_SIZE", 8192))
# ChromaDB requires knowing the dimension of embeddings upfront when
# creating a collection. The embedding dimension is model-specific
# (e.g. text-embedding-3-large uses 3072 dimensions)
# we dynamically determine it by running a test embedding
# and then pass it to the ChromaDBStorage class
async def embedding_func(texts: list[str]) -> np.ndarray:
return await openai_embed(
texts,
model=EMBEDDING_MODEL,
)
async def get_embedding_dimension():
test_text = ["This is a test sentence."]
embedding = await embedding_func(test_text)
return embedding.shape[1]
async def create_embedding_function_instance():
# Get embedding dimension
embedding_dimension = await get_embedding_dimension()
# Create embedding function instance
return EmbeddingFunc(
embedding_dim=embedding_dimension,
max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
func=embedding_func,
)
async def initialize_rag():
embedding_func_instance = await create_embedding_function_instance()
if CHROMADB_USE_LOCAL_PERSISTENT:
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=gpt_4o_mini_complete,
embedding_func=embedding_func_instance,
vector_storage="ChromaVectorDBStorage",
log_level="DEBUG",
embedding_batch_num=32,
vector_db_storage_cls_kwargs={
"local_path": CHROMADB_LOCAL_PATH,
"collection_settings": {
"hnsw:space": "cosine",
"hnsw:construction_ef": 128,
"hnsw:search_ef": 128,
"hnsw:M": 16,
"hnsw:batch_size": 100,
"hnsw:sync_threshold": 1000,
},
},
)
else:
rag = LightRAG(
working_dir=WORKING_DIR,
llm_model_func=gpt_4o_mini_complete,
embedding_func=embedding_func_instance,
vector_storage="ChromaVectorDBStorage",
log_level="DEBUG",
embedding_batch_num=32,
vector_db_storage_cls_kwargs={
"host": CHROMADB_HOST,
"port": CHROMADB_PORT,
"auth_token": CHROMADB_AUTH_TOKEN,
"auth_provider": CHROMADB_AUTH_PROVIDER,
"auth_header_name": CHROMADB_AUTH_HEADER,
"collection_settings": {
"hnsw:space": "cosine",
"hnsw:construction_ef": 128,
"hnsw:search_ef": 128,
"hnsw:M": 16,
"hnsw:batch_size": 100,
"hnsw:sync_threshold": 1000,
},
},
)
await rag.initialize_storages()
await initialize_pipeline_status()
return rag
def main():
# Initialize RAG instance
rag = asyncio.run(initialize_rag())
with open("./book.txt", "r", encoding="utf-8") as f:
rag.insert(f.read())
# Perform naive search
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="naive")
)
)
# Perform local search
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="local")
)
)
# Perform global search
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="global")
)
)
# Perform hybrid search
print(
rag.query(
"What are the top themes in this story?", param=QueryParam(mode="hybrid")
)
)
if __name__ == "__main__":
main()
|