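"""LightRAG example: LMDeploy as the LLM inference backend, with a Hugging Face
sentence-transformers model (via hf_embed) providing the embeddings."""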
import os

from lightrag import LightRAG, QueryParam
from lightrag.llm.lmdeploy import lmdeploy_model_if_cache
from lightrag.llm.hf import hf_embed
from lightrag.utils import EmbeddingFunc
from transformers import AutoModel, AutoTokenizer
from lightrag.kg.shared_storage import initialize_pipeline_status

import asyncio
import nest_asyncio

nest_asyncio.apply()

WORKING_DIR = "./dickens"

if not os.path.exists(WORKING_DIR):
    os.mkdir(WORKING_DIR)


async def lmdeploy_model_complete(
    prompt=None,
    system_prompt=None,
    history_messages=None,
    keyword_extraction=False,
    **kwargs,
) -> str:
    if history_messages is None:
        history_messages = []
    # Resolve the model name from LightRAG's global config (set via llm_model_name below).
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
    return await lmdeploy_model_if_cache(
        model_name,
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        ## Specify chat_template if your local path does not follow the original HF naming,
        ## or if model_name is a PyTorch model on huggingface.co. See
        ## https://github.com/InternLM/lmdeploy/blob/main/lmdeploy/model.py
        ## for the chat templates available in lmdeploy.
        chat_template="llama3",
        # model_format="awq",  # if you are using an AWQ-quantized model
        # quant_policy=8,  # online KV cache quantization: 4 = kv int4, 8 = kv int8
        **kwargs,
    )


async def initialize_rag():
    rag = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=lmdeploy_model_complete,
        llm_model_name="meta-llama/Llama-3.1-8B-Instruct",  # use an explicit local path when loading a local model
        # 384-dimensional embeddings from sentence-transformers/all-MiniLM-L6-v2
        embedding_func=EmbeddingFunc(
            embedding_dim=384,
            max_token_size=5000,
            func=lambda texts: hf_embed(
                texts,
                tokenizer=AutoTokenizer.from_pretrained(
                    "sentence-transformers/all-MiniLM-L6-v2"
                ),
                embed_model=AutoModel.from_pretrained(
                    "sentence-transformers/all-MiniLM-L6-v2"
                ),
            ),
        ),
    )

    # Set up storage backends and the shared document-processing pipeline state
    await rag.initialize_storages()
    await initialize_pipeline_status()

    return rag


def main():
    # Initialize RAG instance
    rag = asyncio.run(initialize_rag())

    # Insert example text (assumes ./book.txt exists)
    with open("./book.txt", "r", encoding="utf-8") as f:
        rag.insert(f.read())

    # Test different query modes
    print("\nNaive Search:")
    print(
        rag.query(
            "What are the top themes in this story?", param=QueryParam(mode="naive")
        )
    )

    print("\nLocal Search:")
    print(
        rag.query(
            "What are the top themes in this story?", param=QueryParam(mode="local")
        )
    )

    print("\nGlobal Search:")
    print(
        rag.query(
            "What are the top themes in this story?", param=QueryParam(mode="global")
        )
    )

    print("\nHybrid Search:")
    print(
        rag.query(
            "What are the top themes in this story?", param=QueryParam(mode="hybrid")
        )
    )


if __name__ == "__main__":
    main()