File size: 3,302 Bytes
359e407
 
 
 
 
 
 
029b1f6
 
359e407
8b3b01c
359e407
 
 
 
 
 
 
 
 
 
 
 
 
c832152
 
 
 
 
359e407
275e33e
8b3b01c
359e407
 
 
 
 
 
b6db833
359e407
552eb31
359e407
 
552eb31
359e407
 
 
 
 
d19a515
552eb31
359e407
c832152
8b3b01c
 
 
 
 
275e33e
8b3b01c
275e33e
9f4950c
8b3b01c
359e407
 
 
 
 
 
 
 
 
 
 
d19a515
 
 
359e407
 
 
 
 
 
d19a515
 
 
359e407
 
 
 
 
 
d19a515
 
 
359e407
 
 
 
 
d19a515
 
 
359e407
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import asyncio
import logging
import os
import time
from dotenv import load_dotenv

from lightrag import LightRAG, QueryParam
from lightrag.llm.zhipu import zhipu_complete
from lightrag.llm.ollama import ollama_embedding
from lightrag.utils import EmbeddingFunc
from lightrag.kg.shared_storage import initialize_pipeline_status

load_dotenv()
# ROOT_DIR comes from the environment (or the .env file loaded above). Fall
# back to the current directory so an unset/empty variable does not silently
# produce paths such as "None/dickens-pg".
ROOT_DIR = os.environ.get("ROOT_DIR") or "."
WORKING_DIR = f"{ROOT_DIR}/dickens-pg"

logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)

# makedirs(exist_ok=True) also creates missing parent directories and avoids
# the check-then-create race of a separate exists()/mkdir() pair.
os.makedirs(WORKING_DIR, exist_ok=True)

# Connection/graph settings for the PG* storages. setdefault() keeps these as
# local defaults only: values already present in the process environment or
# loaded from .env by load_dotenv() above take precedence instead of being
# overwritten.
# AGE
os.environ.setdefault("AGE_GRAPH_NAME", "dickens")

os.environ.setdefault("POSTGRES_HOST", "localhost")
os.environ.setdefault("POSTGRES_PORT", "15432")
os.environ.setdefault("POSTGRES_USER", "rag")
os.environ.setdefault("POSTGRES_PASSWORD", "rag")
os.environ.setdefault("POSTGRES_DATABASE", "rag")


async def initialize_rag():
    """Create and initialize a LightRAG instance using PostgreSQL storages.

    Completion goes through ``zhipu_complete`` with model ``glm-4-flashx``;
    embeddings come from ``ollama_embedding`` using model ``bge-m3`` at
    http://localhost:11434, declared as 1024-dimensional. KV, doc-status,
    graph and vector storage all use the PG* backends. Because
    ``auto_manage_storages_states`` is False, storages and the shared
    pipeline status are initialized explicitly here before returning.
    """

    def _embed(texts):
        # Thin adapter binding the Ollama model name and host.
        return ollama_embedding(
            texts, embed_model="bge-m3", host="http://localhost:11434"
        )

    embedder = EmbeddingFunc(
        embedding_dim=1024,
        max_token_size=8192,
        func=_embed,
    )

    instance = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=zhipu_complete,
        llm_model_name="glm-4-flashx",
        llm_model_max_async=4,
        llm_model_max_token_size=32768,
        enable_llm_cache_for_entity_extract=True,
        embedding_func=embedder,
        kv_storage="PGKVStorage",
        doc_status_storage="PGDocStatusStorage",
        graph_storage="PGGraphStorage",
        vector_storage="PGVectorStorage",
        auto_manage_storages_states=False,
    )

    # Manual lifecycle: storages first, then the shared pipeline status.
    await instance.initialize_storages()
    await initialize_pipeline_status()
    return instance


async def _run_timed_query(rag, question, mode):
    """Run ``question`` in retrieval ``mode``, print the answer and elapsed seconds."""
    label = mode.capitalize()
    print(f"**** Start {label} Query ****")
    # Restart the clock for every mode so each reported time covers only
    # its own query (previously the hybrid time also included the global one).
    start_time = time.perf_counter()
    print(await rag.aquery(question, param=QueryParam(mode=mode)))
    print(f"{label} Query Time: {time.perf_counter() - start_time} seconds")


async def main():
    """Insert ``book.txt`` into the RAG store and time one query per mode.

    Reads ``{ROOT_DIR}/book.txt``, inserts it, then runs the same question in
    naive, local, global and hybrid modes, printing each answer and its
    duration. Storages are finalized even if insertion or a query fails,
    since ``initialize_rag`` disables automatic storage state management.
    """
    # Initialize RAG instance
    rag = await initialize_rag()
    try:
        # add embedding_func for graph database, it's deleted in commit 5661d76860436f7bf5aef2e50d9ee4a59660146c
        rag.chunk_entity_relation_graph.embedding_func = rag.embedding_func

        with open(f"{ROOT_DIR}/book.txt", "r", encoding="utf-8") as f:
            await rag.ainsert(f.read())

        print("==== Trying to test the rag queries ====")
        question = "What are the top themes in this story?"
        for mode in ("naive", "local", "global", "hybrid"):
            await _run_timed_query(rag, question, mode)
    finally:
        # Release storage resources (e.g. the Postgres connection pool).
        await rag.finalize_storages()


if __name__ == "__main__":
    # Script entry point: build the main() coroutine and drive it to
    # completion on a fresh event loop.
    entry_coro = main()
    asyncio.run(entry_coro)