File size: 5,158 Bytes
5a3ba21
 
8b3b01c
 
 
 
5a3ba21
a25342b
 
0553d6a
a25342b
5a3ba21
 
8b3b01c
5a3ba21
a25342b
5a3ba21
 
8b3b01c
5a3ba21
 
 
 
 
a25342b
5a3ba21
a25342b
5a3ba21
 
 
 
a25342b
 
5a3ba21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a25342b
 
5a3ba21
a25342b
 
5a3ba21
0553d6a
5a3ba21
a25342b
5a3ba21
 
 
a25342b
 
 
5a3ba21
 
a25342b
5a3ba21
0553d6a
5a3ba21
a25342b
5a3ba21
 
 
a25342b
 
 
5a3ba21
 
a25342b
 
5a3ba21
 
a25342b
5a3ba21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
275e33e
8b3b01c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5a3ba21
8b3b01c
 
275e33e
8b3b01c
275e33e
 
5a3ba21
 
8b3b01c
9f4950c
a25342b
 
5a3ba21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
import os
import asyncio
import nest_asyncio

nest_asyncio.apply()

from lightrag import LightRAG, QueryParam
from lightrag.llm import (
    openai_complete_if_cache,
    nvidia_openai_embed,
)
from lightrag.utils import EmbeddingFunc
import numpy as np
from lightrag.kg.shared_storage import initialize_pipeline_status

# for custom llm_model_func
from lightrag.utils import locate_json_string_body_from_string


WORKING_DIR = "./dickens"

# exist_ok=True avoids the check-then-create race of os.path.exists + os.mkdir
# and makes re-running the script harmless.
os.makedirs(WORKING_DIR, exist_ok=True)

# Prefer the NVIDIA_OPENAI_API_KEY environment variable so the key does not
# have to live in source code; the placeholder is only a fallback to be
# replaced with your own key for quick local experiments.
NVIDIA_OPENAI_API_KEY = os.getenv("NVIDIA_OPENAI_API_KEY", "nvapi-xxxx")

# Alternatively, use the predefined (OpenAI-compatible) function for the NVIDIA LLM API:
# llm_model_func = nvidia_openai_complete


# If you want to write a custom llm_model_func that calls an LLM on the NVIDIA API (as in the other examples):
async def llm_model_func(
    prompt, system_prompt=None, history_messages=None, keyword_extraction=False, **kwargs
) -> str:
    """Complete *prompt* with ``nvidia/llama-3.1-nemotron-70b-instruct``.

    The request goes through ``openai_complete_if_cache`` against NVIDIA's
    OpenAI-compatible endpoint.

    Args:
        prompt: The user prompt to complete.
        system_prompt: Optional system prompt forwarded unchanged.
        history_messages: Prior chat messages; ``None`` means no history.
        keyword_extraction: When true, return only the JSON body located in
            the model's reply (via ``locate_json_string_body_from_string``).
        **kwargs: Extra options forwarded to ``openai_complete_if_cache``.

    Returns:
        The model reply, or the extracted JSON body when
        ``keyword_extraction`` is set.
    """
    # A mutable default ([]) would be shared across every call; build a fresh
    # list per call instead.
    if history_messages is None:
        history_messages = []
    result = await openai_complete_if_cache(
        "nvidia/llama-3.1-nemotron-70b-instruct",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        api_key=NVIDIA_OPENAI_API_KEY,
        base_url="https://integrate.api.nvidia.com/v1",
        **kwargs,
    )
    if keyword_extraction:
        return locate_json_string_body_from_string(result)
    return result


# custom embedding
nvidia_embed_model = "nvidia/nv-embedqa-e5-v5"  # maximum 512 tokens per input
# Alternative model: "nvidia/llama-3.2-nv-embedqa-1b-v1"


async def _nvidia_embed(texts: list[str], input_type: str) -> np.ndarray:
    """Embed *texts* with ``nvidia_embed_model`` via NVIDIA's OpenAI-compatible API.

    Shared by the indexing and query embedding functions, which differ only in
    the ``input_type`` hint sent to the service.

    Args:
        texts: Strings to embed.
        input_type: ``"passage"`` for documents, ``"query"`` for search queries.

    Returns:
        Whatever ``nvidia_openai_embed`` returns (annotated as ``np.ndarray``).
    """
    return await nvidia_openai_embed(
        texts,
        model=nvidia_embed_model,
        api_key=NVIDIA_OPENAI_API_KEY,
        base_url="https://integrate.api.nvidia.com/v1",
        input_type=input_type,
        # Let the server truncate inputs that exceed the model's token limit
        # instead of rejecting them.
        trunc="END",
        encode="float",
    )


async def indexing_embedding_func(texts: list[str]) -> np.ndarray:
    """Embed document chunks for indexing (``input_type="passage"``)."""
    return await _nvidia_embed(texts, "passage")


async def query_embedding_func(texts: list[str]) -> np.ndarray:
    """Embed search queries (``input_type="query"``).

    NOTE(review): this script configures LightRAG with
    ``indexing_embedding_func`` only, so this function is currently unused.
    """
    return await _nvidia_embed(texts, "query")


# Both embedding functions use the same model, so probing one of them is
# enough to learn the vector width.
async def get_embedding_dim():
    """Return the embedding width by embedding a single probe sentence."""
    vectors = await indexing_embedding_func(["This is a test sentence."])
    return vectors.shape[1]


# Manual smoke test for the LLM and embedding wrappers.
async def test_funcs():
    """Call the LLM and the indexing embedder once each and print the outputs."""
    sample = "How are you?"
    print("llm_model_func: ", await llm_model_func(sample))
    print("embedding_func: ", await indexing_embedding_func([sample]))


# asyncio.run(test_funcs())


async def initialize_rag():
    """Create a ready-to-use LightRAG instance backed by the NVIDIA functions.

    The embedding width is detected at runtime with one probe call. Storages
    and the shared pipeline status are initialized before the instance is
    returned.
    """
    dim = await get_embedding_dim()
    print(f"Detected embedding dimension: {dim}")

    # max_token_size=512 matches the embedding model's input limit. Inputs can
    # still end up longer than that after LightRAG's own tokenization; the
    # embedding calls pass trunc="END" so the server truncates them instead.
    # Aligning LightRAG's tokenizer with the NVIDIA model is left as future work.
    embedder = EmbeddingFunc(
        embedding_dim=dim,
        max_token_size=512,
        func=indexing_embedding_func,
    )

    rag = LightRAG(
        working_dir=WORKING_DIR,
        llm_model_func=llm_model_func,
        # llm_model_name="meta/llama3-70b-instruct",  # optional: uncomment to set a model name
        embedding_func=embedder,
    )
    await rag.initialize_storages()
    await initialize_pipeline_status()
    return rag


async def main():
    """Index ./book.txt, then ask the same question in each query mode.

    Errors are reported (message plus traceback) rather than re-raised. The
    storages are finalized on every exit path once the instance exists.
    """
    import traceback

    question = "What are the top themes in this story?"
    # (query mode, banner printed before its answer)
    searches = (
        ("naive", "==============Naive==============="),
        ("local", "==============local==============="),
        ("global", "==============global==============="),
        ("hybrid", "==============hybrid==============="),
    )

    rag = None
    try:
        # Initialize RAG instance
        rag = await initialize_rag()

        # Read the whole book first so the file is closed before the
        # (potentially long) insertion runs.
        with open("./book.txt", "r", encoding="utf-8") as f:
            book = f.read()
        await rag.ainsert(book)

        for mode, banner in searches:
            print(banner)
            print(await rag.aquery(question, param=QueryParam(mode=mode)))
    except Exception as e:
        # Top-level boundary of a demo script: report, including where it failed.
        print(f"An error occurred: {e}")
        traceback.print_exc()
    finally:
        # Release storage resources acquired by initialize_storages().
        if rag is not None:
            await rag.finalize_storages()


if __name__ == "__main__":
    # Run the async demo only when executed as a script, not when imported.
    asyncio.run(main())