File size: 5,044 Bytes
0fc6305
 
 
0553d6a
0fc6305
 
8b3b01c
0fc6305
 
 
 
 
 
 
 
 
 
 
bef4c6d
 
7149d85
 
 
bef4c6d
0fc6305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0553d6a
0fc6305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bef4c6d
8b3b01c
bef4c6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0fc6305
bef4c6d
 
8b3b01c
bef4c6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0fc6305
8b3b01c
 
0fc6305
8b3b01c
0fc6305
275e33e
 
8b3b01c
 
0fc6305
8b3b01c
 
0fc6305
8b3b01c
 
275e33e
 
 
8b3b01c
0fc6305
8b3b01c
 
275e33e
 
 
8b3b01c
 
 
 
275e33e
 
 
8b3b01c
 
 
 
275e33e
 
 
8b3b01c
 
275e33e
8b3b01c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import os
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.llm.openai import gpt_4o_mini_complete, openai_embed
from lightrag.utils import EmbeddingFunc
import numpy as np
from lightrag.kg.shared_storage import initialize_pipeline_status

#########
# Uncomment the below two lines if running in a jupyter notebook to handle the async nature of rag.insert()
# import nest_asyncio
# nest_asyncio.apply()
#########
WORKING_DIR = "./chromadb_test_dir"
# Make sure the working directory exists before anything tries to use it.
# makedirs(exist_ok=True) replaces the exists()-then-mkdir() pair: it has no
# check-then-create race and also creates missing parent directories.
os.makedirs(WORKING_DIR, exist_ok=True)

# ChromaDB Configuration
# True  -> use the local PersistentClient settings below
# False -> use the remote HttpClient settings below
CHROMADB_USE_LOCAL_PERSISTENT = False

# Local PersistentClient Configuration
_default_chroma_path = os.path.join(WORKING_DIR, "chroma_data")
CHROMADB_LOCAL_PATH = os.getenv("CHROMADB_LOCAL_PATH", _default_chroma_path)

# Remote HttpClient Configuration
# Every value can be overridden through an environment variable of the same name.
CHROMADB_HOST = os.getenv("CHROMADB_HOST", "localhost")
CHROMADB_PORT = int(os.getenv("CHROMADB_PORT", 8000))
CHROMADB_AUTH_TOKEN = os.getenv("CHROMADB_AUTH_TOKEN", "secret-token")
CHROMADB_AUTH_PROVIDER = os.getenv(
    "CHROMADB_AUTH_PROVIDER",
    "chromadb.auth.token_authn.TokenAuthClientProvider",
)
CHROMADB_AUTH_HEADER = os.getenv("CHROMADB_AUTH_HEADER", "X-Chroma-Token")

# Embedding Configuration and Functions
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "text-embedding-3-large")
EMBEDDING_MAX_TOKEN_SIZE = int(os.getenv("EMBEDDING_MAX_TOKEN_SIZE", 8192))

# ChromaDB requires knowing the embedding dimension upfront when
# creating a collection. The dimension is model-specific
# (e.g. text-embedding-3-large uses 3072 dimensions), so we
# determine it dynamically by running a test embedding
# and then pass it (via EmbeddingFunc) to the ChromaVectorDBStorage class.


async def embedding_func(texts: list[str]) -> np.ndarray:
    """Embed ``texts`` using the model named by ``EMBEDDING_MODEL``.

    Thin async wrapper: awaits ``openai_embed`` and returns its result
    unchanged.
    """
    vectors = await openai_embed(texts, model=EMBEDDING_MODEL)
    return vectors


async def get_embedding_dimension():
    """Probe the embedding model and return the width of its vectors.

    A single throwaway sentence is embedded through ``embedding_func`` and
    the size of the second axis of the resulting array is returned.
    """
    probe = await embedding_func(["This is a test sentence."])
    return probe.shape[1]


async def create_embedding_function_instance():
    """Build the ``EmbeddingFunc`` wrapper around ``embedding_func``.

    The embedding dimension is measured at runtime (one test embedding call)
    rather than hard-coded, so switching ``EMBEDDING_MODEL`` needs no other
    change here.
    """
    # Measure first: this awaits a real embedding call.
    dim = await get_embedding_dimension()
    return EmbeddingFunc(
        func=embedding_func,
        embedding_dim=dim,
        max_token_size=EMBEDDING_MAX_TOKEN_SIZE,
    )


async def initialize_rag():
    """Create a ``LightRAG`` instance backed by ChromaDB and initialise it.

    ``CHROMADB_USE_LOCAL_PERSISTENT`` selects which connection settings are
    handed to the vector storage: a local path, or host/port/auth details.
    Everything else passed to ``LightRAG`` is the same in both modes.

    Returns the ``LightRAG`` object after its storages and the pipeline
    status have been initialised.
    """
    embedder = await create_embedding_function_instance()

    # Connection-specific part of the storage kwargs.
    if CHROMADB_USE_LOCAL_PERSISTENT:
        storage_kwargs = {"local_path": CHROMADB_LOCAL_PATH}
    else:
        storage_kwargs = {
            "host": CHROMADB_HOST,
            "port": CHROMADB_PORT,
            "auth_token": CHROMADB_AUTH_TOKEN,
            "auth_provider": CHROMADB_AUTH_PROVIDER,
            "auth_header_name": CHROMADB_AUTH_HEADER,
        }

    # ``hnsw:*`` collection settings, identical for both connection modes.
    storage_kwargs["collection_settings"] = {
        "hnsw:space": "cosine",
        "hnsw:construction_ef": 128,
        "hnsw:search_ef": 128,
        "hnsw:M": 16,
        "hnsw:batch_size": 100,
        "hnsw:sync_threshold": 1000,
    }

    rag = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=gpt_4o_mini_complete,
        embedding_func=embedder,
        vector_storage="ChromaVectorDBStorage",
        log_level="DEBUG",
        embedding_batch_num=32,
        vector_db_storage_cls_kwargs=storage_kwargs,
    )

    # Storages first, then the shared pipeline status.
    await rag.initialize_storages()
    await initialize_pipeline_status()

    return rag


def main():
    """Run the demo: build the RAG, ingest ``./book.txt``, query it four ways.

    The same question is asked once per retrieval mode (naive, local, global,
    hybrid, in that order) and each answer is printed.
    """
    # Initialize RAG instance
    rag = asyncio.run(initialize_rag())

    # Ingest the whole book as a single document.
    with open("./book.txt", "r", encoding="utf-8") as book:
        rag.insert(book.read())

    question = "What are the top themes in this story?"

    # Perform naive, local, global and hybrid search, one after another
    for mode in ("naive", "local", "global", "hybrid"):
        answer = rag.query(question, param=QueryParam(mode=mode))
        print(answer)


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()